diff --git a/api/analyzers/source_analyzer.py b/api/analyzers/source_analyzer.py index 4186f358..9b363cf7 100644 --- a/api/analyzers/source_analyzer.py +++ b/api/analyzers/source_analyzer.py @@ -41,36 +41,52 @@ def supported_types(self) -> list[str]: """ return list(analyzers.keys()) - def create_entity_hierarchy(self, entity: Entity, file: File, analyzer: AbstractAnalyzer, graph: Graph): + def create_entity_hierarchy(self, entity: Entity, file: File, analyzer: AbstractAnalyzer, + pending_entities: list, pending_rels: list): types = analyzer.get_entity_types() stack = list(entity.node.children) while stack: node = stack.pop() if node.type in types: child = Entity(node) - child.id = graph.add_entity(analyzer.get_entity_label(node), analyzer.get_entity_name(node), analyzer.get_entity_docstring(node), str(file.path), node.start_point.row, node.end_point.row, {}) + pending_entities.append(( + child, analyzer.get_entity_label(node), + analyzer.get_entity_name(node), + analyzer.get_entity_docstring(node), + str(file.path), node.start_point.row, + node.end_point.row, {} + )) if not analyzer.is_dependency(str(file.path)): analyzer.add_symbols(child) file.add_entity(child) entity.add_child(child) - graph.connect_entities("DEFINES", entity.id, child.id) - self.create_entity_hierarchy(child, file, analyzer, graph) + pending_rels.append(("DEFINES", entity, child)) + self.create_entity_hierarchy(child, file, analyzer, + pending_entities, pending_rels) else: stack.extend(node.children) - def create_hierarchy(self, file: File, analyzer: AbstractAnalyzer, graph: Graph): + def create_hierarchy(self, file: File, analyzer: AbstractAnalyzer, + pending_entities: list, pending_rels: list): types = analyzer.get_entity_types() stack = [file.tree.root_node] while stack: node = stack.pop() if node.type in types: entity = Entity(node) - entity.id = graph.add_entity(analyzer.get_entity_label(node), analyzer.get_entity_name(node), analyzer.get_entity_docstring(node), str(file.path), node.start_point.row, node.end_point.row, {}) + pending_entities.append(( + entity, analyzer.get_entity_label(node), + analyzer.get_entity_name(node), + analyzer.get_entity_docstring(node), + str(file.path), node.start_point.row, + node.end_point.row, {} + )) if not analyzer.is_dependency(str(file.path)): analyzer.add_symbols(entity) file.add_entity(entity) - graph.connect_entities("DEFINES", file.id, entity.id) - self.create_entity_hierarchy(entity, file, analyzer, graph) + pending_rels.append(("DEFINES", file, entity)) + self.create_entity_hierarchy(entity, file, analyzer, + pending_entities, pending_rels) else: stack.extend(node.children) @@ -87,6 +103,11 @@ def first_pass(self, path: Path, files: list[Path], ignore: list[str], graph: Gr for ext in set([file.suffix for file in files if file.suffix in supoorted_types]): analyzers[ext].add_dependencies(path, files) + # Phase 1: Parse files and build in-memory hierarchy + pending_files = [] + pending_entities = [] + pending_rels = [] + files_len = len(files) for i, file_path in enumerate(files): # Skip none supported files @@ -95,7 +116,7 @@ def first_pass(self, path: Path, files: list[Path], ignore: list[str], graph: Gr continue # Skip ignored files - if any([i in str(file_path) for i in ignore]): + if any(ig in str(file_path) for ig in ignore): logging.info(f"Skipping ignored file {file_path}") continue @@ -110,10 +131,17 @@ def first_pass(self, path: Path, files: list[Path], ignore: list[str], graph: Gr # Create file entity file = File(file_path, tree) self.files[file_path] = file + pending_files.append(file) + + # Walk through the AST and collect entities/relationships + self.create_hierarchy(file, analyzer, pending_entities, pending_rels) - # Walk thought the AST - graph.add_file(file) - self.create_hierarchy(file, analyzer, graph) + # Phase 2: Batch insert files, entities, and relationships + graph.add_files_batch(pending_files) + graph.add_entities_batch(pending_entities) + graph.connect_entities_batch([ + (rel, src.id, dest.id, {}) for rel, src, dest in pending_rels + ]) def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: """ @@ -144,8 +172,11 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: else: lsps[".cs"] = NullLanguageServer() with lsps[".java"].start_server(), lsps[".py"].start_server(), lsps[".cs"].start_server(): + pending_rels = [] files_len = len(self.files) for i, file_path in enumerate(files): + if file_path not in self.files: + continue file = self.files[file_path] logging.info(f'Processing file ({i + 1}/{files_len}): {file_path}') for _, entity in file.entities.items(): @@ -155,18 +186,25 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: if len(symbol.resolved_symbol) == 0: continue resolved_symbol = next(iter(symbol.resolved_symbol)) + rel = None + props = {} if key == "base_class": - graph.connect_entities("EXTENDS", entity.id, resolved_symbol.id) + rel = "EXTENDS" elif key == "implement_interface": - graph.connect_entities("IMPLEMENTS", entity.id, resolved_symbol.id) + rel = "IMPLEMENTS" elif key == "extend_interface": - graph.connect_entities("EXTENDS", entity.id, resolved_symbol.id) + rel = "EXTENDS" elif key == "call": - graph.connect_entities("CALLS", entity.id, resolved_symbol.id, {"line": symbol.symbol.start_point.row, "text": symbol.symbol.text.decode("utf-8")}) + rel = "CALLS" + props = {"line": symbol.symbol.start_point.row, "text": symbol.symbol.text.decode("utf-8")} elif key == "return_type": - graph.connect_entities("RETURNS", entity.id, resolved_symbol.id) + rel = "RETURNS" elif key == "parameters": - graph.connect_entities("PARAMETERS", entity.id, resolved_symbol.id) + rel = "PARAMETERS" + if rel: + pending_rels.append((rel, entity.id, resolved_symbol.id, props)) + + graph.connect_entities_batch(pending_rels) def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None: self.first_pass(path, files, [], graph) diff --git a/api/graph.py b/api/graph.py index eda72e63..95361464 100644 --- a/api/graph.py +++ b/api/graph.py @@ -1,10 +1,18 @@ import os +import re import time +from collections import defaultdict from .entities import * from typing import Optional from falkordb import FalkorDB, Path, Node, QueryResult from falkordb.asyncio import FalkorDB as AsyncFalkorDB +# Maximum items per UNWIND batch to avoid overwhelming FalkorDB/Redis +BATCH_SIZE = 500 + +# Regex to validate graph labels/relation types (alphanumeric + underscore only) +_VALID_LABEL_RE = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$') + # Configure the logger import logging logging.basicConfig(level=logging.DEBUG, @@ -248,6 +256,9 @@ def add_entity(self, label: str, name: str, doc: str, path: str, src_start: int, Args: """ + if not _VALID_LABEL_RE.match(label): + raise ValueError(f"Invalid entity label: {label!r}") + q = f"""MERGE (c:{label}:Searchable {{name: $name, path: $path, src_start: $src_start, src_end: $src_end}}) SET c.doc = $doc @@ -267,6 +278,47 @@ def add_entity(self, label: str, name: str, doc: str, path: str, src_start: int, node = res.result_set[0][0] return node.id + def add_entities_batch(self, entities_data: list) -> None: + """ + Batch add entity nodes to the graph database using UNWIND. + Groups by label, then processes in chunks of BATCH_SIZE. + + Args: + entities_data: list of tuples + (entity_obj, label, name, doc, path, src_start, src_end, props) + entity_obj.id will be set after insertion. + """ + + if not entities_data: + return + + by_label = defaultdict(list) + for item in entities_data: + by_label[item[1]].append(item) + + for label, group in by_label.items(): + if not _VALID_LABEL_RE.match(label): + raise ValueError(f"Invalid entity label: {label!r}") + + q = f"""UNWIND $entities AS e + MERGE (c:{label}:Searchable {{name: e['name'], path: e['path'], + src_start: e['src_start'], + src_end: e['src_end']}}) + SET c.doc = e['doc'] + SET c += e['props'] + RETURN c""" + + for start in range(0, len(group), BATCH_SIZE): + chunk = group[start:start + BATCH_SIZE] + data = [{ + 'name': item[2], 'doc': item[3], 'path': item[4], + 'src_start': item[5], 'src_end': item[6], 'props': item[7] + } for item in chunk] + + res = self._query(q, {'entities': data}) + for j, item in enumerate(chunk): + item[0].id = res.result_set[j][0].id + def get_class_by_name(self, class_name: str) -> Optional[Node]: q = "MATCH (c:Class) WHERE c.name = $name RETURN c LIMIT 1" res = self._query(q, {'name': class_name}).result_set @@ -406,6 +458,30 @@ def add_file(self, file: File) -> None: node = res.result_set[0][0] file.id = node.id + def add_files_batch(self, files: list[File]) -> None: + """ + Batch add file nodes to the graph database using UNWIND. + Processes in chunks of BATCH_SIZE to avoid oversized queries. + + Args: + files: list of File objects. Each file.id will be set after insertion. + """ + + if not files: + return + + q = """UNWIND $files AS fd + MERGE (f:File:Searchable {path: fd['path'], name: fd['name'], ext: fd['ext']}) + RETURN f""" + + for start in range(0, len(files), BATCH_SIZE): + chunk = files[start:start + BATCH_SIZE] + file_data = [{'path': str(f.path), 'name': f.path.name, 'ext': f.path.suffix} + for f in chunk] + res = self._query(q, {'files': file_data}) + for i, row in enumerate(res.result_set): + chunk[i].id = row[0].id + def delete_files(self, files: list[Path]) -> tuple[str, dict, list[int]]: """ Deletes file(s) from the graph in addition to any other entity @@ -485,6 +561,44 @@ def connect_entities(self, relation: str, src_id: int, dest_id: int, properties: params = {'src_id': src_id, 'dest_id': dest_id, "properties": properties} self._query(q, params) + def connect_entities_batch(self, relationships: list[tuple[str, int, int, dict]]) -> None: + """ + Batch create relationships between entities using UNWIND. + Groups by relation type, then processes in chunks of BATCH_SIZE. + + Args: + relationships: list of (relation, src_id, dest_id, properties) + """ + + if not relationships: + return + + by_relation = defaultdict(list) + for rel in relationships: + if rel[1] is None or rel[2] is None: + logging.warning(f"Skipping relationship {rel[0]} with None ID: src={rel[1]}, dest={rel[2]}") + continue + by_relation[rel[0]].append(rel) + + for relation, group in by_relation.items(): + if not _VALID_LABEL_RE.match(relation): + raise ValueError(f"Invalid relation type: {relation!r}") + + q = f"""UNWIND $rels AS r + MATCH (src) + WHERE ID(src) = r['src_id'] + MATCH (dest) + WHERE ID(dest) = r['dest_id'] + MERGE (src)-[e:{relation}]->(dest) + SET e += r['properties'] + RETURN e""" + + for start in range(0, len(group), BATCH_SIZE): + chunk = group[start:start + BATCH_SIZE] + data = [{'src_id': r[1], 'dest_id': r[2], 'properties': r[3]} + for r in chunk] + self._query(q, {'rels': data}) + def function_calls_function(self, caller_id: int, callee_id: int, pos: int) -> None: """ Establish a 'CALLS' relationship between two function nodes. diff --git a/tests/test_graph_ops.py b/tests/test_graph_ops.py index aa137832..66e8db04 100644 --- a/tests/test_graph_ops.py +++ b/tests/test_graph_ops.py @@ -69,5 +69,91 @@ def test_function_calls_function(self): res = self.g.query(query, params).result_set self.assertTrue(res[0][0]) + def test_add_files_batch(self): + files = [File(Path(f'/batch/file{i}.py'), None) for i in range(5)] + self.graph.add_files_batch(files) + + for i, f in enumerate(files): + self.assertIsNotNone(f.id) + result = self.graph.get_file(f'/batch/file{i}.py', f'file{i}.py', '.py') + self.assertIsNotNone(result) + self.assertEqual(result.properties['name'], f'file{i}.py') + + def test_add_files_batch_empty(self): + self.graph.add_files_batch([]) + + def test_add_entities_batch(self): + from unittest.mock import MagicMock + + entities_data = [] + for i in range(3): + mock_entity = MagicMock() + mock_entity.id = None + entities_data.append(( + mock_entity, 'Function', f'func_{i}', f'doc {i}', + '/batch/path', i * 10, i * 10 + 5, {} + )) + + self.graph.add_entities_batch(entities_data) + + for item in entities_data: + self.assertIsNotNone(item[0].id) + + def test_connect_entities_batch(self): + file = File(Path('/batch/connect_test.py'), None) + self.graph.add_file(file) + + func_a_id = self.graph.add_entity( + 'Function', 'batch_a', '', '/batch/connect_test.py', 1, 5, {} + ) + func_b_id = self.graph.add_entity( + 'Function', 'batch_b', '', '/batch/connect_test.py', 6, 10, {} + ) + func_c_id = self.graph.add_entity( + 'Function', 'batch_c', '', '/batch/connect_test.py', 11, 15, {} + ) + + self.graph.connect_entities_batch([ + ("DEFINES", file.id, func_a_id, {}), + ("DEFINES", file.id, func_b_id, {}), + ("DEFINES", file.id, func_c_id, {}), + ("CALLS", func_a_id, func_b_id, {"line": 3, "text": "batch_b()"}), + ]) + + # Verify DEFINES relationships + q = """MATCH (f:File)-[:DEFINES]->(fn:Function) + WHERE ID(f) = $file_id + RETURN count(fn)""" + res = self.g.query(q, {'file_id': file.id}).result_set + self.assertEqual(res[0][0], 3) + + # Verify CALLS relationship with properties + q = """MATCH (a:Function)-[c:CALLS]->(b:Function) + WHERE ID(a) = $a_id AND ID(b) = $b_id + RETURN c.line, c.text""" + res = self.g.query(q, {'a_id': func_a_id, 'b_id': func_b_id}).result_set + self.assertEqual(res[0][0], 3) + self.assertEqual(res[0][1], "batch_b()") + + def test_connect_entities_batch_empty(self): + self.graph.connect_entities_batch([]) + + def test_batch_chunking(self): + """Verify batches are correctly chunked when exceeding BATCH_SIZE.""" + import api.graph as graph_module + original = graph_module.BATCH_SIZE + try: + graph_module.BATCH_SIZE = 3 + files = [File(Path(f'/chunked/f{i}.py'), None) for i in range(7)] + self.graph.add_files_batch(files) + for f in files: + self.assertIsNotNone(f.id) + # Verify all 7 files are actually in the DB + for i in range(7): + result = self.graph.get_file(f'/chunked/f{i}.py', f'f{i}.py', '.py') + self.assertIsNotNone(result) + finally: + graph_module.BATCH_SIZE = original + if __name__ == '__main__': unittest.main()