Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change representation of PrivateKey to not use String #452

Merged
merged 4 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions core/src/main/scala/dev/profunktor/auth/jwt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package dev.profunktor.auth

import java.security.PrivateKey
import java.nio.charset.StandardCharsets
import cats.*
import cats.syntax.all.*
import pdi.jwt.*
Expand All @@ -25,17 +27,52 @@ object jwt {

case class JwtToken(value: String) extends AnyVal

case class JwtSecretKey(value: String) extends AnyVal
object JwtSecretKey {
def apply(key: Array[Byte]): JwtSecretKey = new JwtSecretKeyByteArr(key)
def apply(key: Array[Char]): JwtSecretKey = new JwtSecretKeyCharArr(key)
def apply(key: PrivateKey): JwtSecretKey = new JwtSecretKeyPK(key)
}
sealed trait JwtSecretKey {
froth marked this conversation as resolved.
Show resolved Hide resolved
def value: Array[Char]
}
private class JwtSecretKeyCharArr(val value: Array[Char]) extends JwtSecretKey
private class JwtSecretKeyByteArr(bytes: Array[Byte]) extends JwtSecretKey {
lazy val value = {
val byteBuffer = java.nio.ByteBuffer.wrap(bytes)
val charBuffer = StandardCharsets.UTF_8.decode(byteBuffer)
val charArray = new Array[Char](charBuffer.remaining())
charBuffer.get(charArray)
charArray
}
}
private class JwtSecretKeyPK(key: PrivateKey) extends JwtSecretKey {
lazy val value = {
val byteBuffer = java.nio.ByteBuffer.wrap(key.getEncoded())
val charBuffer = StandardCharsets.UTF_8.decode(byteBuffer)
val charArray = new Array[Char](charBuffer.remaining())
charBuffer.get(charArray)
charArray
}
}

sealed trait JwtAuth
case object JwtNoValidation extends JwtAuth
case class JwtSymmetricAuth(secretKey: JwtSecretKey, jwtAlgorithms: Seq[JwtHmacAlgorithm]) extends JwtAuth
case class JwtAsymmetricAuth(publicKey: JwtPublicKey) extends JwtAuth
object JwtAuth {
def noValidation: JwtAuth = JwtNoValidation
@deprecated(message = "use of string to hold secret keys is deprecated", since = "1.x")
hunterpayne marked this conversation as resolved.
Show resolved Hide resolved
def hmac(secretKey: String, algorithm: JwtHmacAlgorithm): JwtSymmetricAuth =
JwtSymmetricAuth(JwtSecretKey(secretKey.toArray[Char]), Seq(algorithm))
def hmac(secretKey: Array[Char], algorithm: JwtHmacAlgorithm): JwtSymmetricAuth =
JwtSymmetricAuth(JwtSecretKey(secretKey), Seq(algorithm))
@deprecated(message = "use of string to hold secret keys is deprecated", since = "1.x")
froth marked this conversation as resolved.
Show resolved Hide resolved
def hmac(secretKey: String, algorithms: Seq[JwtHmacAlgorithm] = JwtAlgorithm.allHmac()): JwtSymmetricAuth =
JwtSymmetricAuth(JwtSecretKey(secretKey.toArray[Char]), algorithms)
def hmac(
secretKey: Array[Char],
algorithms: Seq[JwtHmacAlgorithm] /* = JwtAlgorithm.allHmac() */
froth marked this conversation as resolved.
Show resolved Hide resolved
): JwtSymmetricAuth =
JwtSymmetricAuth(JwtSecretKey(secretKey), algorithms)
}

Expand All @@ -47,7 +84,7 @@ object jwt {
): F[JwtClaim] =
(jwtAuth match {
case JwtNoValidation => Jwt.decode(jwtToken.value, JwtOptions.DEFAULT.copy(signature = false))
case JwtSymmetricAuth(secretKey, algorithms) => Jwt.decode(jwtToken.value, secretKey.value, algorithms)
case JwtSymmetricAuth(secretKey, algorithms) => Jwt.decode(jwtToken.value, secretKey.value.mkString, algorithms)
case JwtAsymmetricAuth(publicKey) => Jwt.decode(jwtToken.value, publicKey.key, publicKey.algorithm)
}).liftTo[F]

Expand All @@ -56,7 +93,7 @@ object jwt {
jwtSecretKey: JwtSecretKey,
jwtAlgorithm: JwtHmacAlgorithm
): F[JwtToken] =
JwtToken(Jwt.encode(jwtClaim, jwtSecretKey.value, jwtAlgorithm)).pure[F]
JwtToken(Jwt.encode(jwtClaim, jwtSecretKey.value.mkString, jwtAlgorithm)).pure[F]

def jwtEncode[F[_]](jwtClaim: JwtClaim, jwtPrivateKey: JwtPrivateKey)(implicit
F: ApplicativeError[F, Throwable]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ trait JwtFixture {
if (extractId(claim.content) == 123L) AuthUser(123L, "joe").some.pure[IO]
else none[AuthUser].pure[IO]

val jwtAuth = JwtAuth.hmac("53cr3t", JwtAlgorithm.HS256)
val jwtAuth = JwtAuth.hmac("53cr3t".toArray[Char], JwtAlgorithm.HS256)
val middleware = JwtAuthMiddleware[IO, AuthUser](jwtAuth, authenticate)

val adminToken = Jwt.encode(JwtClaim("{123}"), jwtAuth.secretKey.value, jwtAuth.jwtAlgorithms.head)
val noUserToken = Jwt.encode(JwtClaim("{666}"), jwtAuth.secretKey.value, jwtAuth.jwtAlgorithms.head)
val adminToken = Jwt.encode(JwtClaim("{123}"), jwtAuth.secretKey.value.mkString, jwtAuth.jwtAlgorithms.head)
val noUserToken = Jwt.encode(JwtClaim("{666}"), jwtAuth.secretKey.value.mkString, jwtAuth.jwtAlgorithms.head)
val randomToken = Jwt.encode(JwtClaim("{000}"), "secret", jwtAuth.jwtAlgorithms.head)

val rootReq = Request[IO](Method.GET, Uri.unsafeFromString("/"))
Expand Down
Loading