Skip to content

Commit 9e40671

Browse files
committed
significantly optimized and fixed infinite recursion in the privilege analyzer tool
1 parent bf03be5 commit 9e40671

File tree

1 file changed

+81
-93
lines changed

1 file changed

+81
-93
lines changed

tools/privilege_violation_analyzer/call_graph.py

Lines changed: 81 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import defaultdict
1+
from collections import defaultdict, deque
22
from symbol_analyzer import SymbolAnalyzer
33
from colorama import Fore, Style, init as colorama_init
44

@@ -19,7 +19,6 @@ def __init__(self, symbols):
1919
self.call_graph = defaultdict(list) # Maps caller address to list of callee edges
2020
self.addr_to_symbol = {sym['address']: sym for sym in symbols}
2121
self.called_addresses = set()
22-
self.call_paths = None
2322
self.privilege_violations = []
2423
self.privilege_warnings = []
2524

@@ -54,7 +53,6 @@ def build_graph(self):
5453
except ValueError:
5554
pass
5655

57-
self.call_paths = self.generate_call_paths()
5856
self.analyze_privilege_violations()
5957

6058
def find_root_functions(self):
@@ -67,100 +65,90 @@ def find_root_functions(self):
6765
root_functions.append(sym)
6866
return root_functions
6967

70-
def generate_call_paths(self):
68+
def analyze_privilege_violations(self):
7169
"""
72-
Generate all call paths starting from root functions.
73-
Returns a list of paths, each represented as a list of tuples (function name, privilege, address).
70+
Analyze the call graph and detect privilege violations and warnings using BFS.
71+
This approach avoids enumerating all possible paths and is more efficient.
7472
"""
73+
self.privilege_violations = []
74+
self.privilege_warnings = []
75+
seen_violations = set()
76+
seen_warnings = set()
77+
78+
# Track visited nodes to prevent infinite loops
79+
visited = set()
80+
81+
def bfs_analyze_violations(start_addr):
82+
"""
83+
Use BFS to analyze privilege violations from a starting function.
84+
"""
85+
if start_addr in visited:
86+
return
87+
88+
queue = deque([(start_addr, [], False)]) # (addr, path, inherited_privilege)
89+
visited.add(start_addr)
90+
91+
while queue:
92+
current_addr, current_path, inherited_privilege = queue.popleft()
93+
94+
if current_addr not in self.addr_to_symbol:
95+
continue
96+
97+
symbol = self.addr_to_symbol[current_addr]
98+
current_privilege = inherited_privilege or symbol['privileged']
99+
100+
# Add current function to path
101+
current_path = current_path + [(symbol['name'], current_privilege, current_addr)]
102+
103+
# Check all outgoing edges for violations
104+
for edge in self.call_graph.get(current_addr, []):
105+
callee_addr = edge["callee"]
106+
elevated_call_privilege = edge.get("elevated_call_privilege", False)
107+
108+
if callee_addr not in self.addr_to_symbol:
109+
continue
110+
111+
callee_symbol = self.addr_to_symbol[callee_addr]
112+
callee_privileged = callee_symbol['privileged']
113+
114+
# Detect privilege violation
115+
if not current_privilege and callee_privileged and not elevated_call_privilege:
116+
violation_key = (symbol['name'], current_addr, callee_symbol['name'], callee_addr)
117+
if violation_key not in seen_violations:
118+
seen_violations.add(violation_key)
119+
violation = {
120+
"caller": {"name": symbol['name'], "address": current_addr},
121+
"callee": {"name": callee_symbol['name'], "address": callee_addr},
122+
"path": current_path + [(callee_symbol['name'], callee_privileged, callee_addr)],
123+
}
124+
self.privilege_violations.append(violation)
125+
126+
# Detect privilege warning
127+
if current_privilege and callee_privileged and not elevated_call_privilege and not symbol['name'].startswith("dynpriv::"):
128+
warning_key = (symbol['name'], current_addr, callee_symbol['name'], callee_addr)
129+
if warning_key not in seen_warnings:
130+
seen_warnings.add(warning_key)
131+
warning = {
132+
"caller": {"name": symbol['name'], "address": current_addr},
133+
"callee": {"name": callee_symbol['name'], "address": callee_addr},
134+
"path": current_path + [(callee_symbol['name'], callee_privileged, callee_addr)],
135+
}
136+
self.privilege_warnings.append(warning)
137+
138+
# Continue BFS if we haven't visited this callee yet
139+
if callee_addr not in visited:
140+
visited.add(callee_addr)
141+
queue.append((callee_addr, current_path, current_privilege))
142+
143+
# Start analysis from all root functions
75144
root_functions = self.find_root_functions()
76-
paths = []
77-
78-
def dfs(current_path, current_addr, visited, is_privileged):
79-
if current_addr in visited:
80-
return # Prevent cycles
81-
if current_addr not in self.addr_to_symbol:
82-
return # Skip unknown addresses
83-
84-
visited.add(current_addr)
85-
symbol = self.addr_to_symbol[current_addr]
86-
current_privilege = is_privileged or symbol['privileged']
87-
current_path.append((symbol['name'], current_privilege, current_addr))
88-
89-
if current_addr not in self.call_graph:
90-
paths.append(list(current_path)) # Leaf node, save the path
91-
else:
92-
for edge in self.call_graph[current_addr]:
93-
dfs(current_path, edge["callee"], visited, current_privilege)
94-
95-
current_path.pop()
96-
visited.remove(current_addr)
97-
98145
for root in root_functions:
99-
dfs([], root['address'], set(), root['privileged'])
100-
101-
return paths
102-
103-
def analyze_privilege_violations(self):
104-
"""
105-
Analyze the call paths and detect privilege violations and warnings.
106-
A privilege violation occurs when:
107-
- An unprivileged function calls a privileged function.
108-
109-
A privilege warning occurs when:
110-
- An unprivileged function that inherited privilege in the current call chain calls a privileged function.
111-
112-
Stores violations in `self.privilege_violations` and warnings in `self.privilege_warnings`.
113-
"""
114-
self.privilege_violations = [] # Clear previous violations
115-
self.privilege_warnings = [] # Clear previous warnings
116-
seen_violations = set() # Track unique violations
117-
seen_warnings = set() # Track unique warnings
118-
119-
def check_for_violations_and_warnings(path):
120-
for i in range(len(path) - 1):
121-
caller_name, caller_privileged, caller_addr = path[i]
122-
callee_name, callee_privileged, callee_addr = path[i + 1]
123-
124-
# Find the corresponding call edge for elevated privilege check
125-
edge_key = (caller_addr, callee_addr)
126-
call_edges = self.call_graph.get(caller_addr, [])
127-
elevated_call_privilege = any(
128-
edge["callee"] == callee_addr and edge.get("elevated_call_privilege", False)
129-
for edge in call_edges
130-
)
131-
132-
# Detect privilege violation
133-
if not caller_privileged and callee_privileged and not elevated_call_privilege:
134-
violation_key = (
135-
caller_name, caller_addr, callee_name, callee_addr,
136-
tuple((name, privileged) for name, privileged, _ in path[:i + 2]) # Path snapshot
137-
)
138-
if violation_key not in seen_violations:
139-
seen_violations.add(violation_key)
140-
violation = {
141-
"caller": {"name": caller_name, "address": caller_addr},
142-
"callee": {"name": callee_name, "address": callee_addr},
143-
"path": path[:i + 2], # The path leading to the violation
144-
}
145-
self.privilege_violations.append(violation)
146-
147-
# Detect privilege warning
148-
if caller_privileged and not elevated_call_privilege and not caller_name.startswith("dynpriv::"):
149-
warning_key = (
150-
caller_name, caller_addr, callee_name, callee_addr,
151-
tuple((name, privileged) for name, privileged, _ in path[:i + 2]) # Path snapshot
152-
)
153-
if callee_privileged and warning_key not in seen_warnings:
154-
seen_warnings.add(warning_key)
155-
warning = {
156-
"caller": {"name": caller_name, "address": caller_addr},
157-
"callee": {"name": callee_name, "address": callee_addr},
158-
"path": path[:i + 2], # The path leading to the warning
159-
}
160-
self.privilege_warnings.append(warning)
161-
162-
for path in self.call_paths:
163-
check_for_violations_and_warnings(path)
146+
bfs_analyze_violations(root['address'])
147+
148+
# Also analyze from any remaining unvisited functions (in case of cycles)
149+
for sym in self.symbols:
150+
if sym['address'] not in visited:
151+
bfs_analyze_violations(sym['address'])
164152

165153
def print_call_graph_tree(self, sym=None):
166154
"""

0 commit comments

Comments
 (0)