Skip to content

Commit

Permalink
Feature/cleanup action (#8)
Browse files Browse the repository at this point in the history
feat: github action + scalafmt update + dummy test
  • Loading branch information
Dennis Madsen authored Mar 11, 2024
1 parent f25acb3 commit 54c2ae7
Show file tree
Hide file tree
Showing 54 changed files with 727 additions and 682 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: Scala CI

on:
push:
branches: [ main, release* ]
pull_request:
branches: [ main, release* ]

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up JDK 11
uses: actions/setup-java@v2
with:
java-version: 11
- name: Run tests
run: |
sudo apt-get update
sudo apt-get install -y libopengl0
sbt -Djava.awt.headless=true +compile test
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Formatting
run: sbt scalafmtSbtCheck scalafmtCheck test:scalafmtCheck
16 changes: 12 additions & 4 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
version = "3.5.3"
maxColumn = 120
align = true
align.preset = more
version=3.8.0
project.git = true
runner.dialect = scala3
align.openParenCallSite = true
align.openParenDefnSite = true
maxColumn = 120
continuationIndent.defnSite = 2
assumeStandardLibraryStripMargin = true
danglingParentheses.preset = true
rewrite.rules = [SortImports, SortModifiers]
docstrings.style = Asterisk

onTestFailure = "To fix this, run: sbt scalafmtAll"
3 changes: 1 addition & 2 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
ThisBuild / version := "0.1-SNAPSHOT"

ThisBuild / scalaVersion := "3.3.1"

lazy val root = (project in file("."))
Expand Down Expand Up @@ -37,6 +35,7 @@ lazy val root = (project in file("."))
},
javacOptions ++= Seq("-source", "1.8", "-target", "1.8"),
libraryDependencies ++= Seq(
"org.scalatest" %% "scalatest" % "3.2.18" % "test",
"ch.unibas.cs.gravis" %% "scalismo" % "1.0.0",
"io.spray" %% "spray-json" % "1.3.6",
),
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/gingr/api/CorrespondencePairs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package gingr.api

import scalismo.common.PointId
import scalismo.geometry.{Point, _3D}
import scalismo.geometry.{_3D, Point}

case class CorrespondencePairs(pairs: IndexedSeq[(PointId, Point[_3D])])

Expand Down
85 changes: 43 additions & 42 deletions src/main/scala/gingr/api/GeneralRegistrationState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,30 @@ package gingr.api
import breeze.linalg.{DenseMatrix, DenseVector}
import gingr.api.FittingStatuses.FittingStatus
import scalismo.common.PointId
import scalismo.geometry.{EuclideanVector, Landmark, Point, _3D}
import scalismo.geometry.{_3D, EuclideanVector, Landmark, Point}
import scalismo.mesh.TriangleMesh
import scalismo.statisticalmodel.{MultivariateNormalDistribution, PointDistributionModel}
import scalismo.transformations._

case class GeneralRegistrationState(
override val model: PointDistributionModel[_3D, TriangleMesh],
override val modelParameters: ModelFittingParameters,
override val modelLandmarks: Option[Seq[Landmark[_3D]]] = None,
override val target: TriangleMesh[_3D],
override val targetLandmarks: Option[Seq[Landmark[_3D]]] = None,
override val fit: TriangleMesh[_3D],
override val sigma2: Double = 1.0,
override val globalTransformation: GlobalTranformationType = RigidTransforms,
override val stepLength: Double = 1.0,
override val generatedBy: String = "",
override val iteration: Int = 0,
override val status: FittingStatus = FittingStatuses.None
override val model: PointDistributionModel[_3D, TriangleMesh],
override val modelParameters: ModelFittingParameters,
override val modelLandmarks: Option[Seq[Landmark[_3D]]] = None,
override val target: TriangleMesh[_3D],
override val targetLandmarks: Option[Seq[Landmark[_3D]]] = None,
override val fit: TriangleMesh[_3D],
override val sigma2: Double = 1.0,
override val globalTransformation: GlobalTranformationType = RigidTransforms,
override val stepLength: Double = 1.0,
override val generatedBy: String = "",
override val iteration: Int = 0,
override val status: FittingStatus = FittingStatuses.None
) extends RegistrationState[GeneralRegistrationState] {

lazy val landmarkCorrespondences: IndexedSeq[(PointId, Point[_3D], MultivariateNormalDistribution)] = {
if (modelLandmarks.nonEmpty && targetLandmarks.nonEmpty) {
val m = modelLandmarks.get
val t = targetLandmarks.get
val m = modelLandmarks.get
val t = targetLandmarks.get
val commonLmNames = m.map(_.id) intersect t.map(_.id)
commonLmNames.map { name =>
val mPoint = m.find(_.id == name).get
Expand All @@ -61,15 +61,16 @@ case class GeneralRegistrationState(
}
}

/** Updates the current state with the new fit.
*
* @param next
* The newly calculated shape / fit.
*/
/**
* Updates the current state with the new fit.
*
* @param next
* The newly calculated shape / fit.
*/
override def updateFit(next: TriangleMesh[_3D]): GeneralRegistrationState = this.copy(fit = next)
override def updateIteration(): GeneralRegistrationState = this.copy(iteration = this.iteration + 1)
override def clearIteration(): GeneralRegistrationState = this.copy(iteration = 0)
override def updateStatus(next: FittingStatus): GeneralRegistrationState = this.copy(status = next)
override def updateIteration(): GeneralRegistrationState = this.copy(iteration = this.iteration + 1)
override def clearIteration(): GeneralRegistrationState = this.copy(iteration = 0)
override def updateStatus(next: FittingStatus): GeneralRegistrationState = this.copy(status = next)

override private[api] def updateTranslation(next: EuclideanVector[_3D]): GeneralRegistrationState = {
this.copy(modelParameters = this.modelParameters.copy(pose = this.modelParameters.pose.copy(translation = next)))
Expand All @@ -80,7 +81,7 @@ case class GeneralRegistrationState(
}

override private[api] def updateRotation(next: Rotation[_3D]): GeneralRegistrationState = {
val angles = RotationSpace3D.rotMatrixToEulerAngles(next.rotationMatrix)
val angles = RotationSpace3D.rotMatrixToEulerAngles(next.rotationMatrix)
val newEuler = EulerRotation(EulerAngles(angles._1, angles._2, angles._3), next.center)
this.copy(modelParameters = this.modelParameters.copy(pose = this.modelParameters.pose.copy(rotation = newEuler)))
}
Expand Down Expand Up @@ -115,29 +116,29 @@ case class GeneralRegistrationState(

object GeneralRegistrationState {
def apply(
model: PointDistributionModel[_3D, TriangleMesh],
target: TriangleMesh[_3D],
modelTranform: Option[TranslationAfterRotation[_3D]]
model: PointDistributionModel[_3D, TriangleMesh],
target: TriangleMesh[_3D],
modelTranform: Option[TranslationAfterRotation[_3D]]
): GeneralRegistrationState = {
apply(model, target, RigidTransforms, modelTranform)
}

def apply(
model: PointDistributionModel[_3D, TriangleMesh],
target: TriangleMesh[_3D],
transform: GlobalTranformationType,
modelTranform: Option[TranslationAfterRotation[_3D]]
model: PointDistributionModel[_3D, TriangleMesh],
target: TriangleMesh[_3D],
transform: GlobalTranformationType,
modelTranform: Option[TranslationAfterRotation[_3D]]
): GeneralRegistrationState = {
apply(model, Seq(), target, Seq(), transform, modelTranform)
}

def apply(
model: PointDistributionModel[_3D, TriangleMesh],
modelLandmarks: Seq[Landmark[_3D]],
target: TriangleMesh[_3D],
targetLandmarks: Seq[Landmark[_3D]],
transform: GlobalTranformationType,
modelTranform: Option[TranslationAfterRotation[_3D]]
model: PointDistributionModel[_3D, TriangleMesh],
modelLandmarks: Seq[Landmark[_3D]],
target: TriangleMesh[_3D],
targetLandmarks: Seq[Landmark[_3D]],
transform: GlobalTranformationType,
modelTranform: Option[TranslationAfterRotation[_3D]]
): GeneralRegistrationState = {
val (t, r) = if (modelTranform.isDefined) {
val initialAngles = RotationSpace3D.rotMatrixToEulerAngles(modelTranform.get.rotation.rotationMatrix)
Expand Down Expand Up @@ -178,11 +179,11 @@ object GeneralRegistrationState {
}

def apply(
model: PointDistributionModel[_3D, TriangleMesh],
modelLandmarks: Seq[Landmark[_3D]],
target: TriangleMesh[_3D],
targetLandmarks: Seq[Landmark[_3D]],
modelTranform: Option[TranslationAfterRotation[_3D]]
model: PointDistributionModel[_3D, TriangleMesh],
modelLandmarks: Seq[Landmark[_3D]],
target: TriangleMesh[_3D],
targetLandmarks: Seq[Landmark[_3D]],
modelTranform: Option[TranslationAfterRotation[_3D]]
): GeneralRegistrationState = {
apply(model, modelLandmarks, target, targetLandmarks, RigidTransforms, modelTranform)
}
Expand Down
89 changes: 45 additions & 44 deletions src/main/scala/gingr/api/GingrAlgorithm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import gingr.api.sampling.generators.{GeneratorWrapperDeterministic, GeneratorWr
import gingr.api.sampling.loggers.BestAndCurrentSampleLogger
import gingr.api.sampling.{AcceptAll, Evaluator, Generator}
import scalismo.common.PointId
import scalismo.geometry.{Point, _3D}
import scalismo.geometry.{_3D, Point}
import scalismo.mesh.TriangleMesh
import scalismo.registration.LandmarkRegistration
import scalismo.sampling.algorithms.MetropolisHastings
Expand All @@ -43,8 +43,8 @@ import scalismo.utils.{Memoize, Random}
import scala.util.Try

case class ProbabilisticSettings[State <: GingrRegistrationState[State]](
evaluators: Evaluator[State],
randomMixture: Double = 0.5
evaluators: Evaluator[State],
randomMixture: Double = 0.5
) {
require(randomMixture >= 0.0 && randomMixture <= 1.0)
}
Expand All @@ -65,26 +65,26 @@ trait GingrRegistrationState[State] {
trait GingrAlgorithm[State <: GingrRegistrationState[State], config <: GingrConfig] {
val getCorrespondence: State => CorrespondencePairs
val getUncertainty: (PointId, State) => MultivariateNormalDistribution
private val cashedPosterior = Memoize(computePosterior, 10)
private val cashedPosterior = Memoize(computePosterior, 10)
private val retryCounterInitialize = 10
private var retryCounter = 10
private var retryCounter = 10

def name: String

def initializeState(general: GeneralRegistrationState, config: config): State

def estimateRigidTransform(
current: TriangleMesh[_3D],
update: IndexedSeq[(PointId, Point[_3D])]
current: TriangleMesh[_3D],
update: IndexedSeq[(PointId, Point[_3D])]
): TranslationAfterScalingAfterRotation[_3D] = {
val pairs = update.map(f => (current.pointSet.point(f._1), f._2))
val t = LandmarkRegistration.rigid3DLandmarkRegistration(pairs, Point(0, 0, 0))
val t = LandmarkRegistration.rigid3DLandmarkRegistration(pairs, Point(0, 0, 0))
TranslationAfterScalingAfterRotation(t.translation, Scaling(1.0), t.rotation)
}

def estimateSimilarityTransform(
current: TriangleMesh[_3D],
update: IndexedSeq[(PointId, Point[_3D])]
current: TriangleMesh[_3D],
update: IndexedSeq[(PointId, Point[_3D])]
): TranslationAfterScalingAfterRotation[_3D] = {
val pairs = update.map(f => (current.pointSet.point(f._1), f._2))
LandmarkRegistration.similarity3DLandmarkRegistration(pairs, Point(0, 0, 0))
Expand All @@ -94,34 +94,35 @@ trait GingrAlgorithm[State <: GingrRegistrationState[State], config <: GingrConf
model.instance(state.modelParameters.shape.parameters).transform(state.modelParameters.similarityTransform)
}

/** Runs the actual registration with the provided configuration through the passed parameters.
*
* @param initialState
* State from which the registration is started.
* @param callBackLogger
* Logger to provide call back functionality to user after each iteration
* @param acceptRejectLogger
* Logger to use for advanced file logging
* @param probabilisticSettings
* Evaluator to be used if probabilistic registration is set
* @param generators
* Pass in external generators to use
* @param rnd
* Implicit random number generator.
* @return
* Returns the best sample of the chain given the evaluator..
*/
/**
* Runs the actual registration with the provided configuration through the passed parameters.
*
* @param initialState
* State from which the registration is started.
* @param callBackLogger
* Logger to provide call back functionality to user after each iteration
* @param acceptRejectLogger
* Logger to use for advanced file logging
* @param probabilisticSettings
* Evaluator to be used if probabilistic registration is set
* @param generators
* Pass in external generators to use
* @param rnd
* Implicit random number generator.
* @return
* Returns the best sample of the chain given the evaluator..
*/
def run(
initialState: State,
acceptRejectLogger: Option[AcceptRejectLogger[State]],
callBackLogger: Option[ChainStateLogger[State]],
probabilisticSettings: Option[ProbabilisticSettings[State]],
generators: Option[ProposalGenerator[State] with TransitionProbability[State]] = None
initialState: State,
acceptRejectLogger: Option[AcceptRejectLogger[State]],
callBackLogger: Option[ChainStateLogger[State]],
probabilisticSettings: Option[ProbabilisticSettings[State]],
generators: Option[ProposalGenerator[State] with TransitionProbability[State]] = None
)(implicit rnd: Random): State = {
val evaluator = probabilisticSettings.getOrElse(ProbabilisticSettings(AcceptAll(), randomMixture = 0.0))
val evaluator = probabilisticSettings.getOrElse(ProbabilisticSettings(AcceptAll(), randomMixture = 0.0))
val registrationEvaluator = EvaluatorWrapper(probabilisticSettings.nonEmpty, evaluator.evaluators)
val registrationGenerator = generatorCombined(probabilisticSettings, generators)
val bestSampleLogger = BestAndCurrentSampleLogger[State](registrationEvaluator)
val bestSampleLogger = BestAndCurrentSampleLogger[State](registrationEvaluator)
val logs =
if (callBackLogger.isDefined) ChainStateLoggerContainer(Seq(callBackLogger.get, bestSampleLogger))
else bestSampleLogger
Expand All @@ -133,7 +134,7 @@ trait GingrAlgorithm[State <: GingrRegistrationState[State], config <: GingrConf
else mhChain.iterator(initialState).loggedWith(logs)

var converged: Boolean = false
var error: Boolean = false
var error: Boolean = false
// we need to query if there is a next element, otherwise due to laziness the chain is not calculated
var currentState: Option[GeneralRegistrationState] = None
states
Expand Down Expand Up @@ -174,8 +175,8 @@ trait GingrAlgorithm[State <: GingrRegistrationState[State], config <: GingrConf
}

def generatorCombined(
probabilisticSettings: Option[ProbabilisticSettings[State]],
mixing: Option[ProposalGenerator[State] with TransitionProbability[State]]
probabilisticSettings: Option[ProbabilisticSettings[State]],
mixing: Option[ProposalGenerator[State] with TransitionProbability[State]]
)(implicit rnd: Random): ProposalGenerator[State] with TransitionProbability[State] = {
probabilisticSettings match {
case Some(setting) =>
Expand Down Expand Up @@ -207,7 +208,7 @@ trait GingrAlgorithm[State <: GingrRegistrationState[State], config <: GingrConf
}
} else {
retryCounter = math.min(retryCounterInitialize, retryCounter + 1)
val shapeproposal = if (!probabilistic) posterior.get.mean else posterior.get.sample()
val shapeproposal = if (!probabilistic) posterior.get.mean else posterior.get.sample()
val transformedModelInit = current.general.model.transform(current.general.modelParameters.rigidTransform)

val newCoefficients = Try(
Expand Down Expand Up @@ -240,7 +241,7 @@ trait GingrAlgorithm[State <: GingrRegistrationState[State], config <: GingrConf
.updateRotation(newGlobalAlignment.rotation)
.updateScaling(ScaleParameter(globalTransform.scaling.s))
.updateShapeParameters(ShapeParameters(alpha.get))
val newState = current.updateGeneral(general)
val newState = current.updateGeneral(general)
val newSigma2 = updateSigma2(newState)
newState.updateGeneral(newState.general.updateSigma2(newSigma2))
} else {
Expand All @@ -257,8 +258,8 @@ trait GingrAlgorithm[State <: GingrRegistrationState[State], config <: GingrConf
}

def estimateRigidTransform(
current: TriangleMesh[_3D],
update: TriangleMesh[_3D]
current: TriangleMesh[_3D],
update: TriangleMesh[_3D]
): TranslationAfterScalingAfterRotation[_3D] = {
val t = LandmarkRegistration.rigid3DLandmarkRegistration(
current.pointSet.points.toSeq.zip(update.pointSet.points.toSeq),
Expand All @@ -268,8 +269,8 @@ trait GingrAlgorithm[State <: GingrRegistrationState[State], config <: GingrConf
}

def estimateSimilarityTransform(
current: TriangleMesh[_3D],
update: TriangleMesh[_3D]
current: TriangleMesh[_3D],
update: TriangleMesh[_3D]
): TranslationAfterScalingAfterRotation[_3D] = {
LandmarkRegistration.similarity3DLandmarkRegistration(
current.pointSet.points.toSeq.zip(update.pointSet.points.toSeq),
Expand All @@ -281,7 +282,7 @@ trait GingrAlgorithm[State <: GingrRegistrationState[State], config <: GingrConf
val correspondences = getCorrespondence(current)
val correspondencesWithUncertainty = correspondences.pairs.map { pair =>
val (pid, point) = pair
val uncertainty = getUncertainty(pid, current)
val uncertainty = getUncertainty(pid, current)
(pid, point, uncertainty)
}
val observationsWithUncertainty = if (current.config.useLandmarkCorrespondence) {
Expand Down
Loading

0 comments on commit 54c2ae7

Please sign in to comment.