From 9ec9705eea0f4c2c0ab5e144ec76664a48977340 Mon Sep 17 00:00:00 2001 From: mikrise2 Date: Thu, 12 Oct 2023 03:36:57 +0200 Subject: [PATCH] [ML4SE-137] added thresholds to the config. --- .../tasktracker/config/emotion/Emotion.kt | 11 ++++++++ .../config/emotion/EmotionConfig.kt | 27 ++++++++++++++++++- .../modelInference/EmoPredictor.kt | 19 +++++++------ .../modelInference/model/EmoModel.kt | 16 ++++++----- .../tracking/webcam/WebCamService.kt | 21 --------------- .../tracking/webcam/WebCamTracker.kt | 3 --- .../tracking/webcam/WebCamUtils.kt | 6 ++--- .../ui/main/panel/MainPluginPanelFactory.kt | 5 +++- .../tasktracker/config/emotion_default.yaml | 12 +++++++++ 9 files changed, 73 insertions(+), 47 deletions(-) diff --git a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/config/emotion/Emotion.kt b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/config/emotion/Emotion.kt index 6410a4b4..3ba2942a 100644 --- a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/config/emotion/Emotion.kt +++ b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/config/emotion/Emotion.kt @@ -4,6 +4,17 @@ import kotlinx.serialization.Serializable @Serializable open class Emotion( + /** + * Represents position in the model. + * **ATTENTION** only positive positions will be used on prediction. + * Other ones can be used for the default state. + */ val modelPosition: Int, val name: String, + /** + * When we try to predict an emotion, we need a threshold, + * after crossing which we can say for sure that this is the desired emotion. + * So emotion is first prediction with the result >= threshold. + */ + val threshold: Double ) diff --git a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/config/emotion/EmotionConfig.kt b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/config/emotion/EmotionConfig.kt index 51132fe4..1fbddbb9 100644 --- a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/config/emotion/EmotionConfig.kt +++ b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/config/emotion/EmotionConfig.kt @@ -1,6 +1,7 @@ package org.jetbrains.research.tasktracker.config.emotion import kotlinx.serialization.Serializable +import kotlinx.serialization.Transient import org.jetbrains.research.tasktracker.config.BaseConfig import org.jetbrains.research.tasktracker.config.YamlConfigLoadStrategy import org.jetbrains.research.tasktracker.handler.BaseHandler @@ -9,13 +10,37 @@ import java.io.File @Serializable data class EmotionConfig( - val emotions: List, + /** + * All available emotions for tracking + */ + private val emotions: List, + /** + * .onnx model filename + */ + val modelFilename: String, + /** + * Input point name in onnx model + */ + val modelInputGate: String, + /** + * Output point name in onnx model + */ + val modelOutputGate: String ) : BaseConfig { override val configName: String get() = "emotion" override fun buildHandler(): BaseHandler = EmotionHandler(this) + @Transient + private val modelPositionToEmotion = emotions.associateBy { it.modelPosition } + + @Transient + val modelPositionToThreshold = + emotions.filter { it.modelPosition >= 0 }.associate { it.modelPosition to it.threshold } + + fun getEmotion(modelPosition: Int) = modelPositionToEmotion[modelPosition] + companion object { const val CONFIG_FILE_PREFIX: String = "emotion" diff --git a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/modelInference/EmoPredictor.kt b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/modelInference/EmoPredictor.kt index a989da6e..d01314bc 100644 --- a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/modelInference/EmoPredictor.kt +++ b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/modelInference/EmoPredictor.kt @@ -1,19 +1,15 @@ package org.jetbrains.research.tasktracker.modelInference +import org.jetbrains.research.tasktracker.config.emotion.EmotionConfig import org.opencv.core.Mat -class EmoPrediction(val probabilities: Map) { - - companion object { - private const val THRESHOLD = 0.1 - private val SENSITIVE_RANGE: List = (7 downTo 2).toList() - } +class EmoPrediction(val probabilities: Map, private val thresholds: Map) { fun getPrediction(): Int { - for (i in SENSITIVE_RANGE) { - val probability = probabilities[i] ?: error("probability by index `$i` should exist") - if (probability >= THRESHOLD) { - return i + for (threshold in thresholds.entries) { + val probability = probabilities[threshold.key] ?: error("probability by index `$threshold` should exist") + if (probability >= threshold.value) { + return threshold.key } } @@ -22,5 +18,8 @@ class EmoPrediction(val probabilities: Map) { } interface EmoPredictor { + + val emotionConfig: EmotionConfig + suspend fun predict(image: Mat): EmoPrediction } diff --git a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/modelInference/model/EmoModel.kt b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/modelInference/model/EmoModel.kt index ae4be998..30edf4b2 100644 --- a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/modelInference/model/EmoModel.kt +++ b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/modelInference/model/EmoModel.kt @@ -6,13 +6,14 @@ import io.kinference.core.data.tensor.asTensor import io.kinference.core.model.KIModel import io.kinference.ndarray.arrays.FloatNDArray import kotlinx.coroutines.runBlocking +import org.jetbrains.research.tasktracker.config.emotion.EmotionConfig import org.jetbrains.research.tasktracker.modelInference.EmoPrediction import org.jetbrains.research.tasktracker.modelInference.EmoPredictor import org.jetbrains.research.tasktracker.modelInference.getPixel import org.jetbrains.research.tasktracker.modelInference.prepare import org.opencv.core.Mat -class EmoModel : EmoPredictor { +class EmoModel(override val emotionConfig: EmotionConfig) : EmoPredictor { init { runBlocking { @@ -24,7 +25,8 @@ class EmoModel : EmoPredictor { private suspend fun loadModel() { model = KIEngine.loadModel( EmoModel::class.java - .getResource(MODEL_PATH)?.readBytes() ?: error("$MODEL_PATH must exist") + .getResource(emotionConfig.modelFilename)?.readBytes() + ?: error("${emotionConfig.modelFilename} must exist") ) } @@ -33,18 +35,18 @@ class EmoModel : EmoPredictor { val tensor = FloatNDArray(INPUT_SHAPE) { idx: IntArray -> getPixel(idx, prepImage) } - // TODO Rewrite to constants - val outputs = model.predict(listOf(tensor.asTensor("Input3"))) - val output = outputs["Plus692_Output_0"] + + val outputs = model.predict(listOf(tensor.asTensor(emotionConfig.modelInputGate))) + val output = outputs[emotionConfig.modelOutputGate] val softmaxOutput = ((output as KITensor).data as FloatNDArray).softmax() val outputArray = softmaxOutput.array.toArray() val probabilities = outputArray.mapIndexed { index: Int, prob: Float -> index to prob.toDouble() }.toMap() - return EmoPrediction(probabilities) + return EmoPrediction(probabilities, emotionConfig.modelPositionToThreshold) } + // TODO maybe we need to find a better solution for face detection? companion object { - private const val MODEL_PATH = "emotion-ferplus-18.onnx" private val INPUT_SHAPE = intArrayOf(1, 1, 64, 64) } } diff --git a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamService.kt b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamService.kt index 993ccf1d..b454bb4e 100644 --- a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamService.kt +++ b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamService.kt @@ -1,12 +1,9 @@ package org.jetbrains.research.tasktracker.tracking.webcam import com.intellij.openapi.Disposable -import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.components.Service import com.intellij.openapi.diagnostic.Logger import kotlinx.coroutines.runBlocking -import org.jetbrains.research.tasktracker.actions.tracking.NotificationIcons -import org.jetbrains.research.tasktracker.actions.tracking.NotificationWrapper import org.jetbrains.research.tasktracker.modelInference.EmoPredictor import org.jetbrains.research.tasktracker.tracking.logger.WebCamLogger import java.util.* @@ -31,10 +28,6 @@ class WebCamService : Disposable { runBlocking { it.guessEmotionAndLog(emoPredictor, trackerLogger) } - if (photosMade == PHOTOS_MADE_BEFORE_NOTIFICATION) { - showNotification() - photosMade = 0 - } } }, TIME_TO_PHOTO_DELAY, @@ -46,22 +39,8 @@ class WebCamService : Disposable { timerToMakePhoto.cancel() } - private fun showNotification() { - ApplicationManager.getApplication().invokeAndWait { - ApplicationManager.getApplication().runReadAction { - NotificationWrapper( - NotificationIcons.feedbackNotificationIcon, - NotificationWrapper.FEEDBACK_TEXT, - null - ).show() - } - } - } - companion object { // 5 sec private const val TIME_TO_PHOTO_DELAY = 5000L - - private const val PHOTOS_MADE_BEFORE_NOTIFICATION: Int = (60 * 60 / (TIME_TO_PHOTO_DELAY / 1000)).toInt() } } diff --git a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamTracker.kt b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamTracker.kt index a6a40079..a2be3e0b 100644 --- a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamTracker.kt +++ b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamTracker.kt @@ -4,13 +4,10 @@ import com.intellij.openapi.project.Project import org.jetbrains.research.tasktracker.modelInference.EmoPredictor import org.jetbrains.research.tasktracker.tracking.BaseTracker import org.jetbrains.research.tasktracker.tracking.logger.WebCamLogger -import java.util.* class WebCamTracker(private val project: Project, private val emoPredictor: EmoPredictor) : BaseTracker() { override val trackerLogger: WebCamLogger = WebCamLogger(project) - private val timerToMakePhoto = Timer() - override fun startTracking() { project.getService(WebCamService::class.java).startTakingPhotos(emoPredictor, trackerLogger) } diff --git a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamUtils.kt b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamUtils.kt index 29447bb1..b9fa4e6b 100644 --- a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamUtils.kt +++ b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/tracking/webcam/WebCamUtils.kt @@ -3,7 +3,6 @@ package org.jetbrains.research.tasktracker.tracking.webcam import com.intellij.openapi.diagnostic.Logger import kotlinx.coroutines.runBlocking import nu.pattern.OpenCV -import org.jetbrains.research.tasktracker.TaskTrackerPlugin import org.jetbrains.research.tasktracker.config.MainTaskTrackerConfig import org.jetbrains.research.tasktracker.modelInference.EmoPredictor import org.jetbrains.research.tasktracker.tracking.logger.WebCamLogger @@ -89,10 +88,9 @@ suspend fun Mat.guessEmotionAndLog(emoPredictor: EmoPredictor, webcamLogger: Web val photoDate = DateTime.now() val prediction = emoPredictor.predict(this) val modelScore = prediction.getPrediction() - // TODO - TaskTrackerPlugin.mainConfig.emotionConfig!!.emotions.find { it.modelPosition == modelScore }!!.let { + emoPredictor.emotionConfig.getEmotion(modelScore)?.let { webcamLogger.log(it, prediction.probabilities, isRegular, photoDate) - } + } ?: error("can't find emotion with model position `$modelScore`") } catch (e: IllegalStateException) { WebCamUtils.logger.error(e) } diff --git a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/ui/main/panel/MainPluginPanelFactory.kt b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/ui/main/panel/MainPluginPanelFactory.kt index 75f49b6e..4e0daa75 100644 --- a/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/ui/main/panel/MainPluginPanelFactory.kt +++ b/ijPlugin/src/main/kotlin/org/jetbrains/research/tasktracker/ui/main/panel/MainPluginPanelFactory.kt @@ -22,6 +22,7 @@ import io.ktor.http.* import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking +import org.jetbrains.research.tasktracker.TaskTrackerPlugin import org.jetbrains.research.tasktracker.config.content.task.base.Task import org.jetbrains.research.tasktracker.config.content.task.base.TaskWithFiles import org.jetbrains.research.tasktracker.modelInference.model.EmoModel @@ -113,7 +114,9 @@ class MainPluginPanelFactory : ToolWindowFactory { trackers.clear() // TODO: make better shared loggers - GlobalPluginStorage.emoPredictor = EmoModel() + TaskTrackerPlugin.mainConfig.emotionConfig?.let { + GlobalPluginStorage.emoPredictor = EmoModel(it) + } ?: error("emotion config must exist by this moment") val webCamTracker = WebCamTracker(project, GlobalPluginStorage.emoPredictor!!) trackers.addAll( listOf( diff --git a/ijPlugin/src/main/resources/org/jetbrains/research/tasktracker/config/emotion_default.yaml b/ijPlugin/src/main/resources/org/jetbrains/research/tasktracker/config/emotion_default.yaml index d7646f24..61f55cb2 100644 --- a/ijPlugin/src/main/resources/org/jetbrains/research/tasktracker/config/emotion_default.yaml +++ b/ijPlugin/src/main/resources/org/jetbrains/research/tasktracker/config/emotion_default.yaml @@ -1,27 +1,39 @@ +modelFilename: "emotion-ferplus-18.onnx" +modelInputGate: "Input3" +modelOutputGate: "Plus692_Output_0" emotions: - modelPosition: -1 name: "DEFAULT" + threshold: 0.9 - modelPosition: 0 name: "NEUTRAL" + threshold: 0.9 - modelPosition: 1 name: "HAPPINESS" + threshold: 0.9 - modelPosition: 2 name: "SURPRISE" + threshold: 0.1 - modelPosition: 3 name: "SADNESS" + threshold: 0.1 - modelPosition: 4 name: "ANGER" + threshold: 0.1 - modelPosition: 5 name: "DISGUST" + threshold: 0.1 - modelPosition: 6 name: "FEAR" + threshold: 0.1 - modelPosition: 7 name: "CONTEMPT" + threshold: 0.1