Skip to content

Commit 919c57e

Browse files
committed
Add GNN path selector
1 parent 1ccd345 commit 919c57e

File tree

18 files changed

+631
-8
lines changed

18 files changed

+631
-8
lines changed

Game_env/test_model.onnx

80.2 KB
Binary file not shown.

buildSrc/src/main/kotlin/Versions.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ object Versions {
77
const val jcdb = "1.2.0"
88
const val mockk = "1.13.4"
99
const val junitParams = "5.9.3"
10+
const val serialization = "1.5.1"
11+
const val onnxruntime = "1.15.1"
1012
const val logback = "1.4.8"
1113

1214
// versions for jvm samples

buildSrc/src/main/kotlin/usvm.kotlin-conventions.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies {
2020
implementation(kotlin("stdlib-jdk8"))
2121
implementation(kotlin("reflect"))
2222
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutines}")
23+
implementation("com.microsoft.onnxruntime", "onnxruntime", Versions.onnxruntime)
2324

2425
testImplementation(kotlin("test"))
2526
}

usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ sealed class PathsTrieNode<State : UState<*, *, Statement, *, *, State>, Stateme
3434
*/
3535
abstract val depth: Int
3636

37+
/**
38+
* States that were forked from this node
39+
*/
40+
abstract val accumulatedForks: MutableCollection<State>
41+
3742
/**
3843
* Adds a new label to [labels] collection.
3944
*/
@@ -77,9 +82,10 @@ class PathsTrieNodeImpl<State : UState<*, *, Statement, *, *, State>, Statement>
7782
depth = parentNode.depth + 1,
7883
parent = parentNode,
7984
states = hashSetOf(state),
80-
statement = statement
85+
statement = statement,
8186
) {
8287
parentNode.children[statement] = this
88+
parentNode.accumulatedForks.addAll(this.states)
8389
}
8490

8591
internal constructor(parentNode: PathsTrieNodeImpl<State, Statement>, statement: Statement, state: State) : this(
@@ -89,11 +95,14 @@ class PathsTrieNodeImpl<State : UState<*, *, Statement, *, *, State>, Statement>
8995
statement = statement
9096
) {
9197
parentNode.children[statement] = this
98+
parentNode.accumulatedForks.addAll(this.states)
9299
parentNode.states -= state
93100
}
94101

95102
override val labels: MutableSet<Any> = hashSetOf()
96103

104+
override val accumulatedForks: MutableCollection<State> = mutableSetOf()
105+
97106
override fun addLabel(label: Any) {
98107
labels.add(label)
99108
}
@@ -115,6 +124,8 @@ class RootNode<State : UState<*, *, Statement, *, *, State>, Statement> : PathsT
115124

116125
override val labels: MutableSet<Any> = hashSetOf()
117126

127+
override val accumulatedForks: MutableCollection<State> = mutableSetOf()
128+
118129
override val depth: Int = 0
119130

120131
override fun addLabel(label: Any) {
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
package org.usvm.ps
2+
3+
import org.usvm.statistics.ApplicationGraph
4+
5+
data class Block<Statement>(
6+
val id: Int,
7+
var path: MutableList<Statement> = mutableListOf(),
8+
9+
var parents: MutableSet<Block<Statement>> = mutableSetOf(),
10+
var children: MutableSet<Block<Statement>> = mutableSetOf()
11+
) {
12+
override fun hashCode(): Int = id
13+
14+
override fun equals(other: Any?): Boolean {
15+
if (this === other) return true
16+
if (javaClass != other?.javaClass) return false
17+
18+
other as Block<*>
19+
20+
if (id != other.id) return false
21+
22+
return true
23+
}
24+
}
25+
26+
class BlockGraph<Method, Statement>(
27+
initialStatement: Statement,
28+
private val applicationGraph: ApplicationGraph<Method, Statement>,
29+
) {
30+
val root: Block<Statement>
31+
private var nextBlockId: Int = 0
32+
private val blockStatementMapping = HashMap<Block<Statement>, MutableList<Statement>>()
33+
34+
val blocks: Collection<Block<Statement>>
35+
get() = blockStatementMapping.keys
36+
37+
init {
38+
root = buildGraph(initialStatement)
39+
}
40+
41+
fun getGraphBlock(statement: Statement): Block<Statement>? {
42+
blockStatementMapping.forEach {
43+
if (statement in it.value) {
44+
return it.key
45+
}
46+
}
47+
return null
48+
}
49+
50+
private fun initializeGraphBlockWith(statement: Statement): Block<Statement> {
51+
val currentBlock = Block(nextBlockId++, path = mutableListOf(statement))
52+
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(statement)
53+
return currentBlock
54+
}
55+
56+
private fun createAndLinkWithPreds(statement: Statement): Block<Statement> {
57+
val currentBlock = initializeGraphBlockWith(statement)
58+
for (pred in applicationGraph.predecessors(statement)) {
59+
getGraphBlock(pred)?.children?.add(currentBlock)
60+
getGraphBlock(pred)?.let { currentBlock.parents.add(it) }
61+
}
62+
return currentBlock
63+
}
64+
65+
private fun Statement.inBlock() = getGraphBlock(this) != null
66+
67+
private fun ApplicationGraph<Method, Statement>.filterStmtSuccsNotInBlock(
68+
statement: Statement,
69+
forceNewBlock: Boolean
70+
): Sequence<Pair<Statement, Boolean>> {
71+
return this.successors(statement).filter { !it.inBlock() }.map { Pair(it, forceNewBlock) }
72+
}
73+
74+
fun buildGraph(initial: Statement): Block<Statement> {
75+
val root = initializeGraphBlockWith(initial)
76+
var currentBlock = root
77+
val statementQueue = ArrayDeque<Pair<Statement, Boolean>>()
78+
79+
val initialHasMultipleSuccessors = applicationGraph.successors(initial).count() > 1
80+
statementQueue.addAll(
81+
applicationGraph.filterStmtSuccsNotInBlock(
82+
initial,
83+
forceNewBlock = initialHasMultipleSuccessors
84+
)
85+
)
86+
87+
while (statementQueue.isNotEmpty()) {
88+
val (currentStatement, forceNew) = statementQueue.removeFirst()
89+
90+
if (forceNew) {
91+
// don't need to add `currentStatement` succs, we did it earlier
92+
createAndLinkWithPreds(currentStatement)
93+
continue
94+
}
95+
96+
// if statement is a call or if statement has multiple successors: next statements start new block
97+
if (applicationGraph.callees(currentStatement).any() || applicationGraph.successors(currentStatement).count() > 1) {
98+
currentBlock.path.add(currentStatement)
99+
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement)
100+
statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = true))
101+
continue
102+
}
103+
104+
// if statement has multiple ins: next statements start new block
105+
if (applicationGraph.predecessors(currentStatement).count() > 1) {
106+
currentBlock = createAndLinkWithPreds(currentStatement)
107+
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement)
108+
statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = true))
109+
continue
110+
}
111+
112+
currentBlock.path.add(currentStatement)
113+
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement)
114+
statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = false))
115+
}
116+
117+
return root
118+
}
119+
120+
fun getEdges(): List<GameMapEdge> {
121+
return blocks.flatMap { block ->
122+
block.children.map { GameMapEdge(it.id, block.id, GameEdgeLabel(0)) }
123+
}
124+
}
125+
126+
fun getVertices(): Collection<Block<Statement>> = blocks
127+
128+
fun getBlockFeatures(
129+
block: Block<Statement>, isCovered: (Statement) -> Boolean,
130+
inCoverageZone: (Statement) -> Boolean,
131+
isVisited: (Statement) -> Boolean,
132+
stateIdsInBlock: List<UInt>
133+
): BlockFeatures {
134+
val firstStatement = block.path.first()
135+
val lastStatement = block.path.last()
136+
val visitedByState = isVisited(lastStatement)
137+
val touchedByState = visitedByState || isVisited(firstStatement)
138+
139+
return BlockFeatures(
140+
id = block.id,
141+
inCoverageZone = inCoverageZone(firstStatement),
142+
basicBlockSize = block.path.size,
143+
coveredByTest = isCovered(firstStatement),
144+
visitedByState = visitedByState,
145+
touchedByState = touchedByState,
146+
states = stateIdsInBlock
147+
)
148+
}
149+
}

0 commit comments

Comments
 (0)