Skip to content

Commit

Permalink
Merge branch 'master' into renovate/sttpversion
Browse files Browse the repository at this point in the history
  • Loading branch information
dmivankov authored Sep 12, 2023
2 parents cbe6244 + 85a3a8a commit 46bd25c
Show file tree
Hide file tree
Showing 27 changed files with 229 additions and 2,631 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/scala.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ jobs:
# Caching dependencies in Pull Requests based on branch name and build.sbt.
# Can we do something better here?
- name: Cache Coursier dependencies
uses: actions/cache@v1
uses: actions/cache@v3
env:
cache-name: coursier-cache
with:
path: ~/.cache/coursier/v1
key: ${{ runner.os }}-build-${{ env.cache-name }}-${{ github.head_ref }}-${{ hashFiles('**/build.sbt') }}

- name: Cache Ivy 2 cache
uses: actions/cache@v1
uses: actions/cache@v3
env:
cache-name: sbt-ivy2-cache
with:
Expand Down Expand Up @@ -68,15 +68,15 @@ jobs:
# Caching dependencies in Pull Requests based on branch name and build.sbt.
# Can we do something better here?
- name: Cache Coursier dependencies
uses: actions/cache@v1
uses: actions/cache@v3
env:
cache-name: coursier-cache
with:
path: ~/.cache/coursier/v1
key: ${{ runner.os }}-build-${{ env.cache-name }}-${{ github.head_ref }}-${{ hashFiles('**/build.sbt') }}

- name: Cache Ivy 2 cache
uses: actions/cache@v1
uses: actions/cache@v3
env:
cache-name: sbt-ivy2-cache
with:
Expand Down Expand Up @@ -108,15 +108,15 @@ jobs:
# Caching dependencies in Pull Requests based on branch name and build.sbt.
# Can we do something better here?
- name: Cache Coursier dependencies
uses: actions/cache@v1
uses: actions/cache@v3
env:
cache-name: coursier-cache
with:
path: ~/.cache/coursier/v1
key: ${{ runner.os }}-build-${{ env.cache-name }}-${{ github.head_ref }}-${{ hashFiles('**/build.sbt') }}

- name: Cache Ivy 2 cache
uses: actions/cache@v1
uses: actions/cache@v3
env:
cache-name: sbt-ivy2-cache
with:
Expand All @@ -135,7 +135,7 @@ jobs:
cat /dev/null | sbt -Dsbt.log.noformat=true -J-Xmx2G -J-XX:+UseG1GC ++2.13 scalastyle test:scalastyle scalafmtCheck coverage test coverageReport
- name: Upload test coverage report
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.codecov_token }}

Expand Down
9 changes: 5 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ val scala212 = "2.12.17"
val supportedScalaVersions = List(scala212, scala213)

// This is used only for tests.
val jettyTestVersion = "9.4.48.v20220622"
val jettyTestVersion = "9.4.52.v20230823"

val sttpVersion = "3.9.0"
val circeVersion = "0.14.5"
Expand All @@ -17,7 +17,7 @@ val natchezVersion = "0.3.1"

lazy val gpgPass = Option(System.getenv("GPG_KEY_PASSWORD"))

ThisBuild / scalafixDependencies += "org.typelevel" %% "typelevel-scalafix" % "0.1.5"
ThisBuild / scalafixDependencies += "org.typelevel" %% "typelevel-scalafix" % "0.2.0"

lazy val patchVersion = scala.io.Source.fromFile("patch_version.txt").mkString.trim

Expand All @@ -28,6 +28,7 @@ lazy val commonSettings = Seq(
organizationHomepage := Some(url("https://cognite.com")),
version := "2.7." + patchVersion,
isSnapshot := patchVersion.endsWith("-SNAPSHOT"),
scalaVersion := scala213, // use 2.13 by default
crossScalaVersions := supportedScalaVersions,
semanticdbEnabled := true,
semanticdbVersion := scalafixSemanticdb.revision,
Expand Down Expand Up @@ -107,7 +108,7 @@ lazy val core = (project in file("."))
"org.typelevel" %% "cats-effect-testkit" % catsEffectVersion % Test,
"co.fs2" %% "fs2-core" % fs2Version,
"co.fs2" %% "fs2-io" % fs2Version,
"com.google.protobuf" % "protobuf-java" % "3.21.4",
"com.google.protobuf" % "protobuf-java" % "3.24.3",
"org.tpolecat" %% "natchez-core" % natchezVersion
) ++ scalaTestDeps ++ sttpDeps ++ circeDeps(CrossVersion.partialVersion(scalaVersion.value)),
scalacOptions ++= (CrossVersion.partialVersion(scalaVersion.value) match {
Expand Down Expand Up @@ -138,7 +139,7 @@ lazy val core = (project in file("."))
)

val scalaTestDeps = Seq(
"org.scalatest" %% "scalatest" % "3.2.14" % "test"
"org.scalatest" %% "scalatest" % "3.2.17" % "test"
)
val sttpDeps = Seq(
"com.softwaremill.sttp.client3" %% "core" % sttpVersion,
Expand Down
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version = 1.7.1
sbt.version = 1.9.4
6 changes: 3 additions & 3 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.10.3")
addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.11.1")
addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.11.0")
addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.9.21")
addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.1.2-1")
addSbtPlugin("io.github.davidgregory084" % "sbt-tpolecat" % "0.4.1")
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6")
addSbtPlugin("io.github.davidgregory084" % "sbt-tpolecat" % "0.4.4")
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.1")
addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0")
addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.0")
addSbtPlugin("au.com.onegeek" %% "sbt-dotenv" % "2.1.233")
Expand Down
72 changes: 24 additions & 48 deletions src/main/scala/com/cognite/sdk/scala/common/OAuth2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object OAuth2 {
cdfProjectName: String,
audience: Option[String] = None
) {
def getAuth[F[_]](refreshSecondsBeforeExpiration: Long = 30)(
def getAuth[F[_]](
implicit F: Async[F],
clock: Clock[F],
sttpBackend: SttpBackend[F, Any]
Expand All @@ -44,6 +44,7 @@ object OAuth2 {
) // Send empty audience when it is not provided
)
for {
acquiredLowerBound <- clock.realTime.map(_.toSeconds)
response <- basicRequest
.header("Accept", "application/json")
.post(tokenUri)
Expand All @@ -70,8 +71,7 @@ object OAuth2 {
)
}
}
acquiredAt <- clock.realTime.map(_.toSeconds)
expiresAt = acquiredAt + payload.expires_in - refreshSecondsBeforeExpiration
expiresAt = acquiredLowerBound + payload.expires_in
} yield TokenState(payload.access_token, expiresAt, cdfProjectName)
}
}
Expand All @@ -97,7 +97,6 @@ object OAuth2 {
}

def getAuth[F[_]](
refreshSecondsBeforeExpiration: Long = 30,
getToken: Option[F[String]] = None
)(
implicit F: Async[F],
Expand All @@ -108,6 +107,7 @@ object OAuth2 {
val uri = uri"${baseUrl}/api/v1/projects/${cdfProjectName}/sessions/token"
for {
kubernetesServiceToken <- getToken.getOrElse(getKubernetesJwt)
acquiredLowerBound <- clock.realTime.map(_.toSeconds)
payload <- basicRequest
.header("Accept", "application/json")
.header("Authorization", s"Bearer ${kubernetesServiceToken}")
Expand All @@ -118,66 +118,56 @@ object OAuth2 {
)
.send(sttpBackend)
.map(_.body)
acquiredAt <- clock.realTime.map(_.toSeconds)
expiresAt = acquiredAt + payload.expiresIn - refreshSecondsBeforeExpiration
expiresAt = acquiredLowerBound + payload.expiresIn
} yield TokenState(payload.accessToken, expiresAt, cdfProjectName)
}
}

private def commonGetAuth[F[_]](cache: CachedResource[F, TokenState])(
private def commonGetAuth[F[_]](
cache: CachedResource[F, TokenState],
refreshSecondsBeforeExpiration: Long
)(
implicit F: Monad[F],
clock: Clock[F]
): F[Auth] =
for {
now <- clock.realTime.map(_.toSeconds)
_ <- cache.invalidateIfNeeded(_.expiresAt <= now)
_ <- cache.invalidateIfNeeded(_.expiresAt - refreshSecondsBeforeExpiration <= now)
auth <- cache.run(state => F.pure(OidcTokenAuth(state.token, state.cdfProjectName)))
} yield auth

class ClientCredentialsProvider[F[_]] private (
cache: CachedResource[F, TokenState]
val cache: CachedResource[F, TokenState],
val refreshSecondsBeforeExpiration: Long
)(
implicit F: Monad[F],
clock: Clock[F]
) extends AuthProvider[F]
with Serializable {
def getAuth: F[Auth] = commonGetAuth(cache)
def getAuth: F[Auth] = commonGetAuth(cache, refreshSecondsBeforeExpiration)
}

class SessionProvider[F[_]] private (
cache: CachedResource[F, TokenState]
val cache: CachedResource[F, TokenState],
val refreshSecondsBeforeExpiration: Long
)(implicit F: Monad[F], clock: Clock[F])
extends AuthProvider[F]
with Serializable {
def getAuth: F[Auth] = commonGetAuth(cache)
def getAuth: F[Auth] = commonGetAuth(cache, refreshSecondsBeforeExpiration)
}

// scalastyle:off method.length
object ClientCredentialsProvider {
def apply[F[_]](
credentials: ClientCredentials,
refreshSecondsBeforeExpiration: Long = 30,
maybeCacheToken: Option[F[Option[TokenState]]] =
None // can't use F.pure(None) here hence extra Option[] wrapper
initialToken: Option[TokenState] = None
)(
implicit F: Async[F],
clock: Clock[F],
sttpBackend: SttpBackend[F, Any]
): F[ClientCredentialsProvider[F]] = {
val authenticate: F[TokenState] =
for {
now <- clock.realTime.map(_.toSeconds)
maybeTokenInCache <- maybeCacheToken.getOrElse(F.pure(None))
newToken <- maybeTokenInCache match {
case Some(originalToken)
if now < (originalToken.expiresAt - refreshSecondsBeforeExpiration) =>
F.delay(originalToken)
case _ =>
credentials.getAuth()
}
} yield newToken
ConcurrentCachedObject(authenticate).map(new ClientCredentialsProvider[F](_))
}
): F[ClientCredentialsProvider[F]] =
ConcurrentCachedObject(credentials.getAuth, initialToken)
.map(new ClientCredentialsProvider[F](_, refreshSecondsBeforeExpiration))
}
// scalastyle:on method.length

Expand All @@ -186,27 +176,13 @@ object OAuth2 {
session: Session,
refreshSecondsBeforeExpiration: Long = 30,
getToken: Option[F[String]] = None,
maybeCacheToken: Option[F[Option[TokenState]]] =
None // can't use F.pure(None) here hence extra Option[] wrapper
initialToken: Option[TokenState] = None
)(
implicit F: Async[F],
clock: Clock[F],
sttpBackend: SttpBackend[F, Any]
): F[SessionProvider[F]] = {
val authenticate: F[TokenState] =
for {
now <- clock.realTime.map(_.toSeconds)
maybeTokenInCache <- maybeCacheToken.getOrElse(F.pure(None))
newToken <- maybeTokenInCache match {
case Some(originalToken)
if now < (originalToken.expiresAt - refreshSecondsBeforeExpiration) =>
F.delay(originalToken)
case _ =>
session.getAuth(getToken = getToken)
}
} yield newToken
ConcurrentCachedObject(authenticate).map(new SessionProvider[F](_))
}
): F[SessionProvider[F]] =
ConcurrentCachedObject(session.getAuth(getToken = getToken), initialToken)
.map(new SessionProvider[F](_, refreshSecondsBeforeExpiration))
}

private final case class ClientCredentialsResponse(access_token: String, expires_in: Long)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,11 @@ object CachedResource {
}

object ConcurrentCachedObject {
def apply[F[_]: Async, R](acquire: F[R]): F[CachedResource[F, R]] =
Async[F].delay(new ConcurrentCachedObject[F, R](acquire))
def apply[F[_]: Async, R](acquire: F[R], init: Option[R] = None): F[CachedResource[F, R]] =
Async[F]
.delay(new ConcurrentCachedObject[F, R](acquire))
.flatTap(res => init.map(res.initWith).getOrElse(Async[F].unit))
.map(x => x: CachedResource[F, R])
}

// private ctor because of Ref.unsafe in class body, `new` needs `F.delay` around it
Expand Down Expand Up @@ -136,6 +139,10 @@ class ConcurrentCachedObject[F[_], R] private (acquire: F[R])(
// Using `unsafe` just so that I can have RState be an inner type, to avoid useless type parameters on RState
private val cache = Ref.unsafe[F, RState](Empty)

// only meant to be called right after constructor and before any other method
private def initWith(r: R): F[Unit] =
cache.set(Ready(r))

// empty parens to disambiguate the overload
private def transition[A](f: RState => (RState, F[A])): F[A] =
cache.modify(f).flatten
Expand Down
Loading

0 comments on commit 46bd25c

Please sign in to comment.