Skip to content

Commit

Permalink
[ML4SE-137] added thresholds to the config.
Browse files Browse the repository at this point in the history
  • Loading branch information
mikrise2 committed Oct 12, 2023
1 parent 05d20e4 commit 9ec9705
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@ import kotlinx.serialization.Serializable

@Serializable
open class Emotion(
/**
* Represents position in the model.
* **ATTENTION** only positive positions will be used on prediction.
* Other ones can be used for the default state.
*/
val modelPosition: Int,
val name: String,
/**
* When we try to predict an emotion, we need a threshold,
* after crossing which we can say for sure that this is the desired emotion.
* So emotion is first prediction with the result >= threshold.
*/
val threshold: Double
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jetbrains.research.tasktracker.config.emotion

import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient
import org.jetbrains.research.tasktracker.config.BaseConfig
import org.jetbrains.research.tasktracker.config.YamlConfigLoadStrategy
import org.jetbrains.research.tasktracker.handler.BaseHandler
Expand All @@ -9,13 +10,37 @@ import java.io.File

@Serializable
data class EmotionConfig(
val emotions: List<Emotion>,
/**
* All available emotions for tracking
*/
private val emotions: List<Emotion>,
/**
* .onnx model filename
*/
val modelFilename: String,
/**
* Input point name in onnx model
*/
val modelInputGate: String,
/**
* Output point name in onnx model
*/
val modelOutputGate: String
) : BaseConfig {
override val configName: String
get() = "emotion"

override fun buildHandler(): BaseHandler = EmotionHandler(this)

@Transient
private val modelPositionToEmotion = emotions.associateBy { it.modelPosition }

@Transient
val modelPositionToThreshold =
emotions.filter { it.modelPosition >= 0 }.associate { it.modelPosition to it.threshold }

fun getEmotion(modelPosition: Int) = modelPositionToEmotion[modelPosition]

companion object {
const val CONFIG_FILE_PREFIX: String = "emotion"

Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
package org.jetbrains.research.tasktracker.modelInference

import org.jetbrains.research.tasktracker.config.emotion.EmotionConfig
import org.opencv.core.Mat

class EmoPrediction(val probabilities: Map<Int, Double>) {

companion object {
private const val THRESHOLD = 0.1
private val SENSITIVE_RANGE: List<Int> = (7 downTo 2).toList()
}
class EmoPrediction(val probabilities: Map<Int, Double>, private val thresholds: Map<Int, Double>) {

fun getPrediction(): Int {
for (i in SENSITIVE_RANGE) {
val probability = probabilities[i] ?: error("probability by index `$i` should exist")
if (probability >= THRESHOLD) {
return i
for (threshold in thresholds.entries) {
val probability = probabilities[threshold.key] ?: error("probability by index `$threshold` should exist")
if (probability >= threshold.value) {
return threshold.key
}
}

Expand All @@ -22,5 +18,8 @@ class EmoPrediction(val probabilities: Map<Int, Double>) {
}

interface EmoPredictor {

val emotionConfig: EmotionConfig

suspend fun predict(image: Mat): EmoPrediction
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ import io.kinference.core.data.tensor.asTensor
import io.kinference.core.model.KIModel
import io.kinference.ndarray.arrays.FloatNDArray
import kotlinx.coroutines.runBlocking
import org.jetbrains.research.tasktracker.config.emotion.EmotionConfig
import org.jetbrains.research.tasktracker.modelInference.EmoPrediction
import org.jetbrains.research.tasktracker.modelInference.EmoPredictor
import org.jetbrains.research.tasktracker.modelInference.getPixel
import org.jetbrains.research.tasktracker.modelInference.prepare
import org.opencv.core.Mat

class EmoModel : EmoPredictor {
class EmoModel(override val emotionConfig: EmotionConfig) : EmoPredictor {

init {
runBlocking {
Expand All @@ -24,7 +25,8 @@ class EmoModel : EmoPredictor {
private suspend fun loadModel() {
model = KIEngine.loadModel(
EmoModel::class.java
.getResource(MODEL_PATH)?.readBytes() ?: error("$MODEL_PATH must exist")
.getResource(emotionConfig.modelFilename)?.readBytes()
?: error("${emotionConfig.modelFilename} must exist")
)
}

Expand All @@ -33,18 +35,18 @@ class EmoModel : EmoPredictor {
val tensor = FloatNDArray(INPUT_SHAPE) { idx: IntArray ->
getPixel(idx, prepImage)
}
// TODO Rewrite to constants
val outputs = model.predict(listOf(tensor.asTensor("Input3")))
val output = outputs["Plus692_Output_0"]

val outputs = model.predict(listOf(tensor.asTensor(emotionConfig.modelInputGate)))
val output = outputs[emotionConfig.modelOutputGate]
val softmaxOutput = ((output as KITensor).data as FloatNDArray).softmax()
val outputArray = softmaxOutput.array.toArray()

val probabilities = outputArray.mapIndexed { index: Int, prob: Float -> index to prob.toDouble() }.toMap()
return EmoPrediction(probabilities)
return EmoPrediction(probabilities, emotionConfig.modelPositionToThreshold)
}

// TODO maybe we need to find a better solution for face detection?
companion object {
private const val MODEL_PATH = "emotion-ferplus-18.onnx"
private val INPUT_SHAPE = intArrayOf(1, 1, 64, 64)
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package org.jetbrains.research.tasktracker.tracking.webcam

import com.intellij.openapi.Disposable
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.components.Service
import com.intellij.openapi.diagnostic.Logger
import kotlinx.coroutines.runBlocking
import org.jetbrains.research.tasktracker.actions.tracking.NotificationIcons
import org.jetbrains.research.tasktracker.actions.tracking.NotificationWrapper
import org.jetbrains.research.tasktracker.modelInference.EmoPredictor
import org.jetbrains.research.tasktracker.tracking.logger.WebCamLogger
import java.util.*
Expand All @@ -31,10 +28,6 @@ class WebCamService : Disposable {
runBlocking {
it.guessEmotionAndLog(emoPredictor, trackerLogger)
}
if (photosMade == PHOTOS_MADE_BEFORE_NOTIFICATION) {
showNotification()
photosMade = 0
}
}
},
TIME_TO_PHOTO_DELAY,
Expand All @@ -46,22 +39,8 @@ class WebCamService : Disposable {
timerToMakePhoto.cancel()
}

private fun showNotification() {
ApplicationManager.getApplication().invokeAndWait {
ApplicationManager.getApplication().runReadAction {
NotificationWrapper(
NotificationIcons.feedbackNotificationIcon,
NotificationWrapper.FEEDBACK_TEXT,
null
).show()
}
}
}

companion object {
// 5 sec
private const val TIME_TO_PHOTO_DELAY = 5000L

private const val PHOTOS_MADE_BEFORE_NOTIFICATION: Int = (60 * 60 / (TIME_TO_PHOTO_DELAY / 1000)).toInt()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@ import com.intellij.openapi.project.Project
import org.jetbrains.research.tasktracker.modelInference.EmoPredictor
import org.jetbrains.research.tasktracker.tracking.BaseTracker
import org.jetbrains.research.tasktracker.tracking.logger.WebCamLogger
import java.util.*

class WebCamTracker(private val project: Project, private val emoPredictor: EmoPredictor) : BaseTracker() {
override val trackerLogger: WebCamLogger = WebCamLogger(project)

private val timerToMakePhoto = Timer()

override fun startTracking() {
project.getService(WebCamService::class.java).startTakingPhotos(emoPredictor, trackerLogger)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package org.jetbrains.research.tasktracker.tracking.webcam
import com.intellij.openapi.diagnostic.Logger
import kotlinx.coroutines.runBlocking
import nu.pattern.OpenCV
import org.jetbrains.research.tasktracker.TaskTrackerPlugin
import org.jetbrains.research.tasktracker.config.MainTaskTrackerConfig
import org.jetbrains.research.tasktracker.modelInference.EmoPredictor
import org.jetbrains.research.tasktracker.tracking.logger.WebCamLogger
Expand Down Expand Up @@ -89,10 +88,9 @@ suspend fun Mat.guessEmotionAndLog(emoPredictor: EmoPredictor, webcamLogger: Web
val photoDate = DateTime.now()
val prediction = emoPredictor.predict(this)
val modelScore = prediction.getPrediction()
// TODO
TaskTrackerPlugin.mainConfig.emotionConfig!!.emotions.find { it.modelPosition == modelScore }!!.let {
emoPredictor.emotionConfig.getEmotion(modelScore)?.let {
webcamLogger.log(it, prediction.probabilities, isRegular, photoDate)
}
} ?: error("can't find emotion with model position `$modelScore`")
} catch (e: IllegalStateException) {
WebCamUtils.logger.error(e)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import io.ktor.http.*
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.jetbrains.research.tasktracker.TaskTrackerPlugin
import org.jetbrains.research.tasktracker.config.content.task.base.Task
import org.jetbrains.research.tasktracker.config.content.task.base.TaskWithFiles
import org.jetbrains.research.tasktracker.modelInference.model.EmoModel
Expand Down Expand Up @@ -113,7 +114,9 @@ class MainPluginPanelFactory : ToolWindowFactory {
trackers.clear()

// TODO: make better shared loggers
GlobalPluginStorage.emoPredictor = EmoModel()
TaskTrackerPlugin.mainConfig.emotionConfig?.let {
GlobalPluginStorage.emoPredictor = EmoModel(it)
} ?: error("emotion config must exist by this moment")
val webCamTracker = WebCamTracker(project, GlobalPluginStorage.emoPredictor!!)
trackers.addAll(
listOf(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,39 @@
modelFilename: "emotion-ferplus-18.onnx"
modelInputGate: "Input3"
modelOutputGate: "Plus692_Output_0"
emotions:
- modelPosition: -1
name: "DEFAULT"
threshold: 0.9

- modelPosition: 0
name: "NEUTRAL"
threshold: 0.9

- modelPosition: 1
name: "HAPPINESS"
threshold: 0.9

- modelPosition: 2
name: "SURPRISE"
threshold: 0.1

- modelPosition: 3
name: "SADNESS"
threshold: 0.1

- modelPosition: 4
name: "ANGER"
threshold: 0.1

- modelPosition: 5
name: "DISGUST"
threshold: 0.1

- modelPosition: 6
name: "FEAR"
threshold: 0.1

- modelPosition: 7
name: "CONTEMPT"
threshold: 0.1

0 comments on commit 9ec9705

Please sign in to comment.