Skip to content

Commit

Permalink
Proper iinc fix (#168)
Browse files Browse the repository at this point in the history
JacoDB produces incorrect 3-address code for IINC instruction #146
  • Loading branch information
lehvolk authored Sep 5, 2023
1 parent 89d185c commit 27a25a0
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ class RawInstListBuilder(
private val laterAssignments = identityMap<AbstractInsnNode, MutableMap<Int, JcRawValue>>()
private val laterStackAssignments = identityMap<AbstractInsnNode, MutableMap<Int, JcRawValue>>()
private val localTypeRefinement = identityMap<JcRawLocalVar, JcRawLocalVar>()
private val postfixInstructions = hashMapOf<Int, JcRawInst>()

private var labelCounter = 0
private var localCounter = 0
Expand All @@ -213,12 +212,13 @@ class RawInstListBuilder(
private fun buildInstructions() {
currentFrame = createInitialFrame()
frames[ENTRY] = currentFrame
methodNode.instructions.forEachIndexed { index, insn ->
val nodes = methodNode.instructions.toList()
nodes.forEachIndexed { index, insn ->
when (insn) {
is InsnNode -> buildInsnNode(insn)
is FieldInsnNode -> buildFieldInsnNode(insn)
is FrameNode -> buildFrameNode(insn)
is IincInsnNode -> buildIincInsnNode(insn)
is IincInsnNode -> buildIincInsnNode(insn, nodes.getOrNull(index + 1))
is IntInsnNode -> buildIntInsnNode(insn)
is InvokeDynamicInsnNode -> buildInvokeDynamicInsn(insn)
is JumpInsnNode -> buildJumpInsnNode(insn)
Expand Down Expand Up @@ -250,11 +250,14 @@ class RawInstListBuilder(
val insnList = instructionList(insn)
val frame = frames[insn]!!
for ((variable, value) in assignments) {
if (value != frame[variable]) {
if (insn.isBranchingInst || insn.isTerminateInst) {
insnList.addInst(insn, JcRawAssignInst(method, value, frame[variable]!!), insnList.lastIndex)
val frameVariable = frame[variable]
if (frameVariable != null && value != frameVariable) {
if (insn.isBranchingInst) {
insnList.addInst(insn, JcRawAssignInst(method, value, frameVariable), 0)
}else if(insn.isTerminateInst) {
insnList.addInst(insn, JcRawAssignInst(method, value, frameVariable), insnList.lastIndex)
} else {
insnList.addInst(insn, JcRawAssignInst(method, value, frame[variable]!!))
insnList.addInst(insn, JcRawAssignInst(method, value, frameVariable))
}
}
}
Expand Down Expand Up @@ -375,11 +378,16 @@ class RawInstListBuilder(
return currentFrame.locals.getValue(variable)
}

private fun local(variable: Int, expr: JcRawValue, insn: AbstractInsnNode): JcRawAssignInst? {
private fun local(variable: Int, expr: JcRawValue, insn: AbstractInsnNode, override: Boolean = false): JcRawAssignInst? {
val oldVar = currentFrame.locals[variable]
return if (oldVar != null) {
if (oldVar.typeName == expr.typeName || (expr is JcRawNullConstant && !oldVar.typeName.isPrimitive)) {
JcRawAssignInst(method, oldVar, expr)
if (override) {
currentFrame = currentFrame.put(variable, expr)
JcRawAssignInst(method, expr, expr)
} else {
JcRawAssignInst(method, oldVar, expr)
}
} else if (expr is JcRawSimpleValue) {
currentFrame = currentFrame.put(variable, expr)
null
Expand All @@ -403,7 +411,7 @@ class RawInstListBuilder(
private fun instructionList(insn: AbstractInsnNode) = instructions.getOrPut(insn, ::mutableListOf)

private fun addInstruction(insn: AbstractInsnNode, inst: JcRawInst, index: Int? = null) {
instructionList(insn).addInst(insn, inst, index)
instructionList(insn).addInst(insn, inst, index)
}

private fun MutableList<JcRawInst>.addInst(node: AbstractInsnNode, inst: JcRawInst, index: Int? = null) {
Expand All @@ -412,18 +420,6 @@ class RawInstListBuilder(
} else {
add(inst)
}
if (postfixInstructions.isNotEmpty()) {
when {
node.isBranchingInst -> postfixInstructions.forEach {
instructionList(node).add(0, it.value)
}

inst !is JcRawReturnInst -> postfixInstructions.forEach {
instructionList(node).add(it.value)
}
}
postfixInstructions.clear()
}
}

private fun nextRegister(typeName: TypeName): JcRawValue {
Expand Down Expand Up @@ -825,7 +821,7 @@ class RawInstListBuilder(
* a helper function that helps to merge local variables from several predecessor frames into one map
* if all the predecessor frames are known (meaning we already visited all the corresponding instructions
* in the bytecode) --- merge process is trivial
* if some predecessor frames are unknown, we remebmer them and add requried assignment instructions after
* if some predecessor frames are unknown, we remember them and add required assignment instructions after
* the full construction process is complete, see #buildRequiredAssignments function
*/
private fun SortedMap<Int, TypeName>.copyLocals(predFrames: Map<AbstractInsnNode, Frame?>): Map<Int, JcRawValue> =
Expand Down Expand Up @@ -1017,13 +1013,18 @@ class RawInstListBuilder(
}
}

private fun buildIincInsnNode(insnNode: IincInsnNode) {
private fun buildIincInsnNode(insnNode: IincInsnNode, nextInst: AbstractInsnNode?) {
val variable = insnNode.`var`
val local = local(variable)
postfixInstructions[variable] = JcRawAssignInst(method, local,
JcRawAddExpr(local.typeName, local, JcRawInt(insnNode.incr))
)
local(variable, local, insnNode)
val incrementedVariable = when {
nextInst != null && nextInst.isBranchingInst -> local
nextInst != null && (
(nextInst is VarInsnNode && nextInst.`var` == variable) || nextInst is LabelNode) -> local
else -> nextRegister(local.typeName)
}
val add = JcRawAddExpr(local.typeName, local, JcRawInt(insnNode.incr))
instructionList(insnNode) += JcRawAssignInst(method, incrementedVariable, add)
local(variable, incrementedVariable, insnNode, override = incrementedVariable != local)
}

private fun buildIntInsnNode(insnNode: IntInsnNode) {
Expand Down Expand Up @@ -1428,10 +1429,6 @@ class RawInstListBuilder(

in Opcodes.ILOAD..Opcodes.ALOAD -> {
push(local(variable))
postfixInstructions[variable]?.let {
postfixInstructions.remove(variable)
instructionList(insnNode).add(it) // do not reuse `addInstruction` function here
}
}
else -> error("Unknown opcode ${insnNode.opcode} in VarInsnNode")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ internal class Simplifier {
}
}


private class SimplifierCollector : AbstractFullRawExprSetCollector() {
val exprs = hashSetOf<JcRawSimpleValue>()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,47 +40,41 @@ abstract class BaseInstructionsTest : BaseTest() {

val ext = runBlocking { cp.hierarchyExt() }

protected fun testClass(klass: JcClassOrInterface) {
testAndLoadClass(klass, false)
protected fun testClass(klass: JcClassOrInterface, validateLineNumbers: Boolean = true) {
testAndLoadClass(klass, false, validateLineNumbers)
}

protected fun testAndLoadClass(klass: JcClassOrInterface): Class<*> {
return testAndLoadClass(klass, true)!!
return testAndLoadClass(klass, true, validateLineNumbers = true)!!
}

private fun testAndLoadClass(klass: JcClassOrInterface, loadClass: Boolean): Class<*>? {
private fun testAndLoadClass(klass: JcClassOrInterface, loadClass: Boolean, validateLineNumbers: Boolean): Class<*>? {
try {
val classNode = klass.asmNode()
classNode.methods = klass.declaredMethods.filter { it.enclosingClass == klass }.map {
if (it.isAbstract || it.name.contains("$\$forInline")) {
it.asmNode()
} else {
try {
// val oldBody = it.body()
// println()
// println("Old body: ${oldBody.print()}")
val instructionList = it.rawInstList
it.instList.forEachIndexed { index, inst ->
Assertions.assertEquals(index, inst.location.index, "indexes not matched for $it at $index")
}
// println("Instruction list: $instructionList")
val graph = it.flowGraph()
if (!it.enclosingClass.isKotlin) {
val methodMsg = "$it should have line number"
graph.instructions.forEach {
Assertions.assertTrue(it.lineNumber > 0, methodMsg)
if (validateLineNumbers) {
graph.instructions.forEach {
Assertions.assertTrue(it.lineNumber > 0, methodMsg)
}
}
}
graph.applyAndGet(OverridesResolver(ext)) {}
JcGraphChecker(it, graph).check()
// println("Graph: $graph")
// graph.view("/usr/bin/dot", "/usr/bin/firefox", false)
// graph.blockGraph().view("/usr/bin/dot", "/usr/bin/firefox")
val newBody = MethodNodeBuilder(it, instructionList).build()
// println("New body: ${newBody.print()}")
// println()
newBody
} catch (e: Exception) {
} catch (e: Throwable) {
it.dumpInstructions()
throw IllegalStateException("error handling $it", e)
}

Expand Down
16 changes: 12 additions & 4 deletions jacodb-core/src/test/kotlin/org/jacodb/testing/cfg/IRTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,12 @@ class IRTest : BaseInstructionsTest() {
testClass(cp.findClass<BinarySearchTree<*>.BinarySearchTreeIterator>())
}

@Test
fun `get ir of random class`() {
val clazz = cp.findClass("kotlinx.coroutines.channels.ChannelsKt__DeprecatedKt\$filterIndexed\$1")
val method = clazz.declaredMethods.first { it.name == "invokeSuspend" }
JcGraphChecker(method, method.flowGraph()).check()
}

@Test
fun `get ir of self`() {
Expand All @@ -299,20 +305,22 @@ class IRTest : BaseInstructionsTest() {
}

// todo: make this test green
// @Test
@Test
fun `get ir of kotlinx-coroutines`() {
// testClass(cp.findClass("kotlinx.coroutines.ThreadContextElementKt"))
runAlongLib(kotlinxCoroutines)
runAlongLib(kotlinxCoroutines, false)
}



@AfterEach
fun printStats() {
cp.features!!.filterIsInstance<ClasspathCache>().forEach {
it.dumpStats()
}
}

private fun runAlongLib(file: File) {
private fun runAlongLib(file: File, validateLineNumbers: Boolean = true) {
println("Run along: ${file.absolutePath}")

val classes = JarLocation(file, isRuntime = false, object : JavaVersion {
Expand All @@ -324,7 +332,7 @@ class IRTest : BaseInstructionsTest() {
val clazz = cp.findClass(it.key)
if (!clazz.isAnnotation && !clazz.isInterface) {
println("Testing class: ${it.key}")
testClass(clazz)
testClass(clazz, validateLineNumbers)
}
}
}
Expand Down
98 changes: 98 additions & 0 deletions jacodb-core/src/test/kotlin/org/jacodb/testing/cfg/IincTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2022 UnitTestBot contributors (utbot.org)
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.jacodb.testing.cfg

import org.jacodb.api.ext.findClass
import org.jacodb.testing.WithDB
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test

class IincTest : BaseInstructionsTest() {

companion object : WithDB()

@Test
fun `iinc should work`() {
val clazz = cp.findClass<Incrementation>()

val javaClazz = testAndLoadClass(clazz)
val method = javaClazz.methods.first { it.name == "iinc" }
val res = method.invoke(null, 0)
assertEquals(0, res)
}

@Test
fun `iinc arrayIntIdx should work`() {
val clazz = cp.findClass<Incrementation>()

val javaClazz = testAndLoadClass(clazz)
val method = javaClazz.methods.first { it.name == "iincArrayIntIdx" }
val res = method.invoke(null)
assertArrayEquals(intArrayOf(1, 0, 2), res as IntArray)
}

@Test
fun `iinc arrayByteIdx should work`() {
val clazz = cp.findClass<Incrementation>()

val javaClazz = testAndLoadClass(clazz)
val method = javaClazz.methods.first { it.name == "iincArrayByteIdx" }
val res = method.invoke(null)
assertArrayEquals(intArrayOf(1, 0, 2), res as IntArray)
}

@Test
fun `iinc for`() {
val clazz = cp.findClass<Incrementation>()

val javaClazz = testAndLoadClass(clazz)
val method = javaClazz.methods.first { it.name == "iincFor" }
val res = method.invoke(null)
assertArrayEquals(intArrayOf(0, 1, 2, 3, 4), res as IntArray)
}

@Test
fun `iinc if`() {
val clazz = cp.findClass<Incrementation>()

val javaClazz = testAndLoadClass(clazz)
val method = javaClazz.methods.first { it.name == "iincIf" }
assertArrayEquals(intArrayOf(), method.invoke(null, true, true) as IntArray)
assertArrayEquals(intArrayOf(0), method.invoke(null, true, false) as IntArray)
}

@Test
fun `iinc if 2`() {
val clazz = cp.findClass<Incrementation>()

val javaClazz = testAndLoadClass(clazz)
val method = javaClazz.methods.first { it.name == "iincIf2" }
assertEquals(2, method.invoke(null, 1))
assertEquals(4, method.invoke(null, 2))
}

// @Test
fun `iinc while`() {
val clazz = cp.findClass<Incrementation>()

val javaClazz = testAndLoadClass(clazz)
val method = javaClazz.methods.first { it.name == "iincWhile" }
assertEquals(2, method.invoke(null) as IntArray)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -267,47 +267,6 @@ class InstructionsTest : BaseInstructionsTest() {
assertEquals("defaultMethod", callDefaultMethod.method.method.name)
}


@Test
fun `iinc should work`() {
val clazz = cp.findClass<Incrementation>()

val javaClazz = testAndLoadClass(clazz)
val method = javaClazz.methods.first { it.name == "iinc" }
val res = method.invoke(null, 0)
assertEquals(0, res)
}

@Test
fun `iinc arrayIntIdx should work`() {
val clazz = cp.findClass<Incrementation>()

val javaClazz = testAndLoadClass(clazz)
val method = javaClazz.methods.first { it.name == "iincArrayIntIdx" }
val res = method.invoke(null)
assertArrayEquals(intArrayOf(1, 0, 2), res as IntArray)
}

@Test
fun `iinc arrayByteIdx should work`() {
val clazz = cp.findClass<Incrementation>()

val javaClazz = testAndLoadClass(clazz)
val method = javaClazz.methods.first { it.name == "iincArrayByteIdx" }
val res = method.invoke(null)
assertArrayEquals(intArrayOf(1, 0, 2), res as IntArray)
}

@Test
fun `iinc for`() {
val clazz = cp.findClass<Incrementation>()

val javaClazz = testAndLoadClass(clazz)
val method = javaClazz.methods.first { it.name == "iincFor" }
val res = method.invoke(null)
assertArrayEquals(intArrayOf(0, 1, 2, 3, 4), res as IntArray)
}

}

fun JcMethod.dumpInstructions(): String {
Expand Down
Loading

0 comments on commit 27a25a0

Please sign in to comment.