diff --git a/api/analyzers/c/__init__.py b/api/analyzers/c/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/analyzers/c/analyzer.py b/api/analyzers/c/analyzer.py index aa25420a..6aa2af9e 100644 --- a/api/analyzers/c/analyzer.py +++ b/api/analyzers/c/analyzer.py @@ -1,476 +1,120 @@ -# import os -# from pathlib import Path - -# from multilspy import SyncLanguageServer -# from ...entities import * -# from ...graph import Graph -# from typing import Optional -# from ..analyzer import AbstractAnalyzer - -# import tree_sitter_c as tsc -# from tree_sitter import Language, Node - -# import logging -# logger = logging.getLogger('code_graph') - -# class CAnalyzer(AbstractAnalyzer): -# def __init__(self) -> None: -# super().__init__(Language(tsc.language())) - -# def get_entity_label(self, node: Node) -> str: -# if node.type == 'struct_specifier': -# return "Struct" -# elif node.type == 'function_definition': -# return "Function" -# raise ValueError(f"Unknown entity type: {node.type}") - -# def get_entity_name(self, node: Node) -> str: -# if node.type in ['struct_specifier', 'function_definition']: -# return node.child_by_field_name('name').text.decode('utf-8') -# raise ValueError(f"Unknown entity type: {node.type}") - -# def get_entity_docstring(self, node: Node) -> Optional[str]: -# if node.type in ['struct_specifier', 'function_definition']: -# body = node.child_by_field_name('body') -# if body.child_count > 0 and body.children[0].type == 'expression_statement': -# docstring_node = body.children[0].child(0) -# return docstring_node.text.decode('utf-8') -# return None -# raise ValueError(f"Unknown entity type: {node.type}") - -# def get_entity_types(self) -> list[str]: -# return ['struct_specifier', 'function_definition'] - -# def process_pointer_declaration(self, node: Node) -> tuple[str, int]: -# """ -# Processes a pointer declaration node to determine the argument name and pointer count. 
- -# Args: -# node (Node): The AST node representing a pointer declaration. - -# Returns: -# Tuple[str, int]: A tuple containing the argument name and the pointer count. -# """ - -# assert(node.type == 'pointer_declarator') - -# text = node.text.decode('utf-8') -# idx = max(text.rfind(' '), text.rfind('*')) + 1 -# name = text[idx:] -# t = text[:idx] - -# return (t, name) - -# def process_parameter_declaration(self, node: Node) -> tuple[bool, str, int, str]: -# """ -# Processes a parameter declaration node to determine its properties. - -# Args: -# node (Node): The AST node representing a parameter declaration. - -# Returns: -# Tuple[bool, str, int, Optional[str]]: A tuple containing: -# - A boolean indicating if the parameter is const. -# - A string representing the argument type. -# - An integer representing the pointer count. -# - An optional string for the argument name (None if not found). -# """ - -# assert(node.type == 'parameter_declaration') - -# const = False -# pointer = 0 -# arg_name = '' -# arg_type = '' - -# for child in node.children: -# t = child.type - -# if t == 'type_qualifier': -# child = child.children[0] -# if child.type == 'const': -# const = True - -# elif t == 'type_identifier': -# arg_type = child.text.decode('utf-8') - -# elif t == 'identifier': -# arg_name = child.text.decode('utf-8') - -# elif t == 'primitive_type': -# arg_type = child.text.decode('utf-8') - -# elif t == 'pointer_declarator': -# pointer_arg_name, arg_name = self.process_pointer_declaration(child) -# arg_type += pointer_arg_name - -# elif t == 'sized_type_specifier': -# arg_type = child.text.decode('utf-8') - -# return (const, arg_type, pointer, arg_name) - -# def process_function_definition_node(self, node: Node, path: Path, -# source_code: str) -> Optional[Function]: -# """ -# Processes a function definition node to extract function details. - -# Args: -# node (Node): The AST node representing a function definition. 
-# path (Path): The file path where the function is defined. - -# Returns: -# Optional[Function]: A Function object containing details about the function, or None if the function name cannot be determined. -# """ - -# # Extract function name -# res = find_child_of_type(node, 'function_declarator') -# if res is None: -# return None - -# function_declarator = res[0] - -# res = find_child_of_type(function_declarator, 'identifier') -# if res is None: -# return None - -# identifier = res[0] -# function_name = identifier.text.decode('utf-8') -# logger.info(f"Function declaration: {function_name}") - -# # Extract function return type -# res = find_child_of_type(node, 'primitive_type') -# ret_type = 'Unknown' -# if res is not None: -# ret_type = res[0] -# ret_type = ret_type.text.decode('utf-8') - -# # Extract function parameters -# args = [] -# res = find_child_of_type(function_declarator, 'parameter_list') -# if res is not None: -# parameters = res[0] - -# # Extract arguments and their types -# for child in parameters.children: -# if child.type == 'parameter_declaration': -# arg = self.process_parameter_declaration(child) -# args.append(arg) - -# # Extract function definition line numbers -# start_line = node.start_point[0] -# end_line = node.end_point[0] - -# # Create Function object -# docs = '' -# src = source_code[node.start_byte:node.end_byte] -# f = Function(str(path), function_name, docs, ret_type, src, start_line, end_line) - -# # Add arguments to Function object -# for arg in args: -# const = arg[0] -# t = arg[1] -# pointer = arg[2] -# name = arg[3] - -# # Skip f(void) -# if name is None and t == 'void': -# continue - -# type_str = 'const ' if const else '' -# type_str += t -# type_str += '*' * pointer - -# f.add_argument(name, type_str) - -# return f - -# def process_function_definition(self, parent: File, node: Node, path: Path, -# graph: Graph, source_code: str) -> None: -# """ -# Processes a function definition node and adds it to the graph. 
- -# Args: -# parent (File): The parent File object. -# node (Node): The AST node representing the function definition. -# path (Path): The file path where the function is defined. -# graph (Graph): The Graph object to which the function entity will be added. - -# Returns: -# None -# """ - -# assert(node.type == 'function_definition') - -# entity = self.process_function_definition_node(node, path, source_code) -# if entity is not None: -# # Add Function object to the graph -# try: -# graph.add_function(entity) -# except Exception: -# logger.error(f"Failed creating function: {entity}") -# entity = None - -# if entity is not None: -# # Connect parent to entity -# graph.connect_entities('DEFINES', parent.id, entity.id) - -# def process_field_declaration(self, node: Node) -> Optional[tuple[str, str]]: -# """ -# Processes a field declaration node to extract field name and type. - -# Args: -# node (Node): The AST node representing a field declaration. - -# Returns: -# Optional[Tuple[str, str]]: A tuple containing the field name and type, or None if either could not be determined. 
-# """ - -# assert(node.type == 'field_declaration') - -# const = False -# field_name = None -# field_type = '' - -# for child in node.children: -# if child.type == 'field_identifier': -# field_name = child.text.decode('utf-8') -# elif child.type == 'type_qualifier': -# const = True -# elif child.type == 'struct_specifier': -# # TODO: handle nested structs -# # TODO: handle union -# pass -# elif child.type == 'primitive_type': -# field_type = child.text.decode('utf-8') -# elif child.type == 'sized_type_specifier': -# field_type = child.text.decode('utf-8') -# elif child.type == 'pointer_declarator': -# pointer_field_type, field_name = self.process_pointer_declaration(child) -# field_type += pointer_field_type -# elif child.type == 'array_declarator': -# field_type += '[]' -# field_name = child.children[0].text.decode('utf-8') -# else: -# continue - -# if field_type is not None and const is True: -# field_type = f'const {field_type}' - -# if field_name is not None and field_type is not None: -# return (field_name, field_type) -# else: -# return None - -# def process_struct_specifier_node(self, node: Node, path: Path) -> Optional[Struct]: -# """ -# Processes a struct specifier node to extract struct fields. - -# Args: -# node (Node): The AST node representing the struct specifier. -# path (Path): The file path where the struct is defined. - -# Returns: -# Optional[Struct]: A Struct object containing details about the struct, or None if the struct name or fields could not be determined. 
-# """ - -# # Do not process struct without a declaration_list -# res = find_child_of_type(node, 'field_declaration_list') -# if res is None: -# return None - -# field_declaration_list = res[0] - -# # Extract struct name -# res = find_child_of_type(node, 'type_identifier') -# if res is None: -# return None - -# type_identifier = res[0] -# struct_name = type_identifier.text.decode('utf-8') - -# start_line = node.start_point[0] -# end_line = node.end_point[0] -# s = Struct(str(path), struct_name, '', start_line, end_line) - -# # Collect struct fields -# for child in field_declaration_list.children: -# if child.type == 'field_declaration': -# res = self.process_field_declaration(child) -# if res is None: -# return None -# else: -# field_name, field_type = res -# s.add_field(field_name, field_type) - -# return s - -# def process_struct_specifier(self, parent: File, node: Node, path: Path, -# graph: Graph) -> Node: -# """ -# Processes a struct specifier node to extract struct details and adds it to the graph. - -# Args: -# parent (File): The parent File object. -# node (Node): The AST node representing the struct specifier. -# path (Path): The file path where the struct is defined. -# graph (Graph): The Graph object to which the struct entity will be added. - -# Returns: -# Optional[Node]: The processed AST node representing the struct specifier if successful, otherwise None. - -# Raises: -# AssertionError: If the provided node is not of type 'struct_specifier'. 
-# """ - -# assert(node.type == 'struct_specifier') - -# entity = self.process_struct_specifier_node(node, path) -# if entity is not None: -# # Add Struct object to the graph -# try: -# graph.add_struct(entity) -# except Exception: -# logger.warning(f"Failed creating struct: {entity}") -# entity = None - -# if entity is not None: -# # Connect parent to entity -# graph.connect_entities('DEFINES', parent.id, entity.id) - -# def add_symbols(self, entity: Entity) -> None: -# if entity.node.type == 'struct_specifier': -# superclasses = entity.node.child_by_field_name("superclasses") -# if superclasses: -# base_classes_query = self.language.query("(argument_list (_) @base_class)") -# base_classes_captures = base_classes_query.captures(superclasses) -# if 'base_class' in base_classes_captures: -# for base_class in base_classes_captures['base_class']: -# entity.add_symbol("base_class", base_class) -# elif entity.node.type == 'function_definition': -# query = self.language.query("(call) @reference.call") -# captures = query.captures(entity.node) -# if 'reference.call' in captures: -# for caller in captures['reference.call']: -# entity.add_symbol("call", caller) -# query = self.language.query("(typed_parameter type: (_) @parameter)") -# captures = query.captures(entity.node) -# if 'parameter' in captures: -# for parameter in captures['parameter']: -# entity.add_symbol("parameters", parameter) -# return_type = entity.node.child_by_field_name('return_type') -# if return_type: -# entity.add_symbol("return_type", return_type) - -# def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, path: Path, node: Node) -> list[Entity]: -# res = [] -# for file, resolved_node in self.resolve(files, lsp, path, node): -# type_dec = self.find_parent(resolved_node, ['struct_specifier']) -# res.append(file.entities[type_dec]) -# return res - -# def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, path: Path, node: Node) -> list[Entity]: -# res = [] -# for 
file, resolved_node in self.resolve(files, lsp, path, node): -# method_dec = self.find_parent(resolved_node, ['function_definition']) -# if not method_dec: -# continue -# if method_dec in file.entities: -# res.append(file.entities[method_dec]) -# return res - -# def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, path: Path, key: str, symbol: Node) -> Entity: -# if key in ["parameters", "return_type"]: -# return self.resolve_type(files, lsp, path, symbol) -# elif key in ["call"]: -# return self.resolve_method(files, lsp, path, symbol) -# else: -# raise ValueError(f"Unknown key {key}") - -# def second_pass(self, path: Path, graph: Graph) -> None: -# """ -# Perform the second pass processing of a C source file or header file to establish function call relationships. - -# Args: -# path (Path): The path to the C source file or header file. -# f (io.TextIOWrapper): The file object representing the opened C source file or header file. -# graph (Graph): The Graph object containing entities (functions and files) to establish relationships. - -# Returns: -# None - -# This function processes the specified C source file or header file to establish relationships between -# functions based on function calls. It performs the following steps: - -# - Checks if the file path ends with '.c' or '.h'. If not, logs a debug message and skips processing. -# - Retrieves the file entity (`file`) from the graph based on the file path. -# - Parses the content of the file using a parser instance (`self.parser`). If parsing fails, logs an error. -# - Uses Tree-sitter queries (`query_function_def` and `query_call_exp`) to locate function definitions and -# function invocations (calls) within the parsed AST (`tree.root_node`). -# - Iterates over captured function definitions (`function_defs`) and their corresponding function calls -# (`function_calls`). For each function call: -# - Retrieves or creates a function entity (`callee_f`) in the graph. 
-# - Connects the caller function (`caller_f`) to the callee function (`callee_f`) using a 'CALLS' edge in -# the graph. - -# Note: -# - This function assumes that function calls to native functions (e.g., 'printf') will create missing -# function entities (`Function` objects) and add them to the graph. - -# Example usage: -# ``` -# second_pass(Path('/path/to/file.c'), open('/path/to/file.c', 'r'), graph) -# ``` -# """ - -# if path.suffix != '.c' and path.suffix != '.h': -# logger.debug(f"Skipping none C file {path}") -# return - -# logger.info(f"Processing {path}") - -# # Get file entity -# file = graph.get_file(os.path.dirname(path), path.name, path.suffix) -# if file is None: -# logger.error(f"File entity not found for: {path}") -# return - -# try: -# # Parse file -# content = path.read_bytes() -# tree = self.parser.parse(content) -# except Exception as e: -# logger.error(f"Failed to process file {path}: {e}") -# return - -# # Locate function invocation -# query_call_exp = self.language.query("(call_expression function: (identifier) @callee)") - -# # Locate function definitions -# query_function_def = self.language.query(""" -# ( -# function_definition -# declarator: (function_declarator -# declarator: (identifier) @function_name) -# )""") - -# function_defs = query_function_def.captures(tree.root_node) -# for function_def in function_defs: -# caller = function_def[0] -# caller_name = caller.text.decode('utf-8') -# caller_f = graph.get_function_by_name(caller_name) -# assert(caller_f is not None) - -# function_calls = query_call_exp.captures(caller.parent.parent) -# for function_call in function_calls: -# callee = function_call[0] -# callee_name = callee.text.decode('utf-8') -# callee_f = graph.get_function_by_name(callee_name) - -# if callee_f is None: -# # Create missing function -# # Assuming this is a call to a native function e.g. 
'printf' -# callee_f = Function('/', callee_name, None, None, None, 0, 0) -# graph.add_function(callee_f) - -# # Connect the caller and callee in the graph -# graph.connect_entities('CALLS', caller_f.id, callee_f.id) +from pathlib import Path +from typing import Optional + +from multilspy import SyncLanguageServer +from ...entities.entity import Entity +from ...entities.file import File +from ..analyzer import AbstractAnalyzer + +import tree_sitter_c as tsc +from tree_sitter import Language, Node + +import logging +logger = logging.getLogger('code_graph') + + +class CAnalyzer(AbstractAnalyzer): + def __init__(self) -> None: + super().__init__(Language(tsc.language())) + + def add_dependencies(self, path: Path, files: list[Path]): + pass + + def get_entity_label(self, node: Node) -> str: + if node.type == 'struct_specifier': + return "Struct" + elif node.type == 'function_definition': + return "Function" + raise ValueError(f"Unknown entity type: {node.type}") + + def get_entity_name(self, node: Node) -> str: + if node.type == 'struct_specifier': + name_node = node.child_by_field_name('name') + if name_node: + return name_node.text.decode('utf-8') + raise ValueError("Struct has no name") + elif node.type == 'function_definition': + declarator = node.child_by_field_name('declarator') + if declarator: + name_node = declarator.child_by_field_name('declarator') + if name_node: + return name_node.text.decode('utf-8') + raise ValueError("Function has no name") + raise ValueError(f"Unknown entity type: {node.type}") + + def get_entity_docstring(self, node: Node) -> Optional[str]: + if node.type in ['struct_specifier', 'function_definition']: + if node.prev_sibling and node.prev_sibling.type == 'comment': + return node.prev_sibling.text.decode('utf-8') + return None + raise ValueError(f"Unknown entity type: {node.type}") + + def get_entity_types(self) -> list[str]: + return ['struct_specifier', 'function_definition'] + + def add_symbols(self, entity: Entity) -> None: + if 
entity.node.type == 'function_definition': + # Find function calls + captures = self._captures("(call_expression function: (identifier) @reference.call)", entity.node) + if 'reference.call' in captures: + for caller in captures['reference.call']: + entity.add_symbol("call", caller) + + # Find parameters + captures = self._captures("(parameter_declaration type: (_) @parameter)", entity.node) + if 'parameter' in captures: + for parameter in captures['parameter']: + entity.add_symbol("parameters", parameter) + + # Return type + return_type = entity.node.child_by_field_name('type') + if return_type: + entity.add_symbol("return_type", return_type) + + def is_dependency(self, file_path: str) -> bool: + return False + + def resolve_path(self, file_path: str, path: Path) -> str: + return file_path + + def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: + res = [] + for file, resolved_node in self.resolve(files, lsp, file_path, path, node): + type_dec = self.find_parent(resolved_node, ['struct_specifier']) + if type_dec in file.entities: + res.append(file.entities[type_dec]) + return res + + def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: + res = [] + if node.type == 'call_expression': + func_node = node.child_by_field_name('function') + if func_node: + node = func_node + for file, resolved_node in self.resolve(files, lsp, file_path, path, node): + method_dec = self.find_parent(resolved_node, ['function_definition']) + if method_dec and method_dec in file.entities: + res.append(file.entities[method_dec]) + return res + + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: + if key in ["parameters", "return_type"]: + return self.resolve_type(files, lsp, file_path, path, symbol) + elif key in ["call"]: + return 
self.resolve_method(files, lsp, file_path, path, symbol) + else: + raise ValueError(f"Unknown key {key}") + + def get_include_paths(self, tree) -> list[str]: + """Extract #include paths from a parsed C file.""" + includes = [] + captures = self._captures( + "(preproc_include [(string_literal) (system_lib_string)] @include)", + tree.root_node + ) + if 'include' in captures: + for node in captures['include']: + path_text = node.text.decode('utf-8').strip('"<>') + if path_text: + includes.append(path_text) + return includes diff --git a/api/analyzers/source_analyzer.py b/api/analyzers/source_analyzer.py index 9046abcf..2545675f 100644 --- a/api/analyzers/source_analyzer.py +++ b/api/analyzers/source_analyzer.py @@ -7,7 +7,7 @@ from ..graph import Graph from .analyzer import AbstractAnalyzer -# from .c.analyzer import CAnalyzer +from .c.analyzer import CAnalyzer from .csharp.analyzer import CSharpAnalyzer from .java.analyzer import JavaAnalyzer from .javascript.analyzer import JavaScriptAnalyzer @@ -24,8 +24,8 @@ # List of available analyzers analyzers: dict[str, AbstractAnalyzer] = { - # '.c': CAnalyzer(), - # '.h': CAnalyzer(), + '.c': CAnalyzer(), + '.h': CAnalyzer(), '.py': PythonAnalyzer(), '.java': JavaAnalyzer(), '.cs': CSharpAnalyzer(), @@ -152,7 +152,10 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: lsps[".kt"] = NullLanguageServer() lsps[".kts"] = NullLanguageServer() lsps[".js"] = NullLanguageServer() - with lsps[".java"].start_server(), lsps[".py"].start_server(), lsps[".cs"].start_server(), lsps[".js"].start_server(), lsps[".kt"].start_server(), lsps[".kts"].start_server(): + # C doesn't have a multilspy language server + lsps[".c"] = NullLanguageServer() + lsps[".h"] = NullLanguageServer() + with lsps[".java"].start_server(), lsps[".py"].start_server(), lsps[".cs"].start_server(), lsps[".js"].start_server(), lsps[".kt"].start_server(), lsps[".kts"].start_server(), lsps[".c"].start_server(), lsps[".h"].start_server(): 
files_len = len(self.files) for i, file_path in enumerate(files): if file_path not in self.files: @@ -185,7 +188,7 @@ def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None: def analyze_sources(self, path: Path, ignore: list[str], graph: Graph) -> None: path = path.resolve() - files = list(path.rglob("*.java")) + list(path.rglob("*.py")) + list(path.rglob("*.cs")) + [f for f in path.rglob("*.js") if "node_modules" not in f.parts] + list(path.rglob("*.kt")) + list(path.rglob("*.kts")) + files = list(path.rglob("*.java")) + list(path.rglob("*.py")) + list(path.rglob("*.cs")) + [f for f in path.rglob("*.js") if "node_modules" not in f.parts] + list(path.rglob("*.kt")) + list(path.rglob("*.kts")) + list(path.rglob("*.c")) + list(path.rglob("*.h")) # First pass analysis of the source code self.first_pass(path, files, ignore, graph) diff --git a/tests/source_files/c/src.c b/tests/source_files/c/src.c index 22118abf..af7fe061 100644 --- a/tests/source_files/c/src.c +++ b/tests/source_files/c/src.c @@ -1,3 +1,7 @@ +#include "myheader.h" +#include <stdio.h> + +/* Adds two integers */ int add ( int a, diff --git a/tests/test_c_analyzer.py b/tests/test_c_analyzer.py index ea8a16b7..edcfedb5 100644 --- a/tests/test_c_analyzer.py +++ b/tests/test_c_analyzer.py @@ -1,70 +1,141 @@ -import os import unittest from pathlib import Path -from api import SourceAnalyzer, Graph - - -class Test_C_Analyzer(unittest.TestCase): - def test_analyzer(self): - path = Path(__file__).parent - analyzer = SourceAnalyzer() - - # Get the current file path - current_file_path = os.path.abspath(__file__) - - # Get the directory of the current file - current_dir = os.path.dirname(current_file_path) - - # Append 'source_files/c' to the current directory - path = os.path.join(current_dir, 'source_files') - path = os.path.join(path, 'c') - path = str(path) - - g = Graph("c") - analyzer.analyze_local_folder(path, g) - - f = g.get_file('', 'src.c', '.c') - self.assertIsNotNone(f) - 
self.assertEqual(f.properties['name'], 'src.c') - self.assertEqual(f.properties['ext'], '.c') - - s = g.get_struct_by_name('exp') - self.assertIsNotNone(s) - self.assertEqual(s.properties['name'], 'exp') - self.assertEqual(s.properties['path'], 'src.c') - self.assertEqual(s.properties['src_start'], 9) - self.assertEqual(s.properties['src_end'], 13) - self.assertEqual(s.properties['fields'], [['i', 'int'], ['f', 'float'], ['data', 'char[]']]) - - add = g.get_function_by_name('add') - self.assertIsNotNone(add) - self.assertEqual(add.properties['name'], 'add') - self.assertEqual(add.properties['path'], 'src.c') - self.assertEqual(add.properties['ret_type'], 'int') - self.assertEqual(add.properties['src_start'], 0) - self.assertEqual(add.properties['src_end'], 7) - self.assertEqual(add.properties['args'], [['a', 'int'], ['b', 'int']]) - self.assertIn('a + b', add.properties['src']) - - main = g.get_function_by_name('main') - self.assertIsNotNone(main) - self.assertEqual(main.properties['name'], 'main') - self.assertEqual(main.properties['path'], 'src.c') - self.assertEqual(main.properties['ret_type'], 'int') - self.assertEqual(main.properties['src_start'], 15) - self.assertEqual(main.properties['src_end'], 18) - self.assertEqual(main.properties['args'], [['argv', 'const char**'], ['argc', 'int']]) - self.assertIn('x = add', main.properties['src']) - - callees = g.function_calls(main.id) - self.assertEqual(len(callees), 1) - self.assertEqual(callees[0], add) - - callers = g.function_called_by(add.id) - callers = [caller.properties['name'] for caller in callers] - - self.assertEqual(len(callers), 2) - self.assertIn('add', callers) - self.assertIn('main', callers) +from api.analyzers.c.analyzer import CAnalyzer +from api.entities.entity import Entity +from api.entities.file import File + +def _entity_name(analyzer, entity): + """Get the name of an entity using the analyzer.""" + return analyzer.get_entity_name(entity.node) + + +class TestCAnalyzer(unittest.TestCase): + 
"""Test the C analyzer's entity extraction (no DB required).""" + + @classmethod + def setUpClass(cls): + cls.analyzer = CAnalyzer() + source_dir = Path(__file__).parent / "source_files" / "c" + cls.sample_path = source_dir / "src.c" + source = cls.sample_path.read_bytes() + tree = cls.analyzer.parser.parse(source) + cls.file = File(cls.sample_path, tree) + + # Walk AST and extract entities + types = cls.analyzer.get_entity_types() + stack = [tree.root_node] + while stack: + node = stack.pop() + if node.type in types: + entity = Entity(node) + cls.analyzer.add_symbols(entity) + cls.file.add_entity(entity) + stack.extend(node.children) + else: + stack.extend(node.children) + + # Extract includes + cls.includes = cls.analyzer.get_include_paths(tree) + + def _entity_names(self): + return [_entity_name(self.analyzer, e) for e in self.file.entities.values()] + + def test_entity_types(self): + """Analyzer should recognise C entity types.""" + self.assertEqual( + self.analyzer.get_entity_types(), + ['struct_specifier', 'function_definition'], + ) + + def test_function_extraction(self): + """Functions should be extracted from src.c.""" + names = self._entity_names() + self.assertIn("add", names) + self.assertIn("main", names) + + def test_struct_extraction(self): + """Structs should be extracted from src.c.""" + names = self._entity_names() + self.assertIn("exp", names) + + def test_function_label(self): + """Functions should get the 'Function' label.""" + for entity in self.file.entities.values(): + if _entity_name(self.analyzer, entity) == "add": + self.assertEqual(self.analyzer.get_entity_label(entity.node), "Function") + + def test_struct_label(self): + """Structs should get the 'Struct' label.""" + for entity in self.file.entities.values(): + if _entity_name(self.analyzer, entity) == "exp": + self.assertEqual(self.analyzer.get_entity_label(entity.node), "Struct") + + def test_call_symbols(self): + """Function 'main' should have call symbols (calls to 'add').""" + for 
entity in self.file.entities.values(): + if _entity_name(self.analyzer, entity) == "main": + call_syms = entity.symbols.get("call", []) + self.assertTrue(len(call_syms) > 0, "main should have call symbols") + + def test_include_extraction(self): + """Include directives should be extracted.""" + self.assertIn("myheader.h", self.includes) + self.assertIn("stdio.h", self.includes) + + def test_is_dependency(self): + """is_dependency should return False for C files.""" + self.assertFalse(self.analyzer.is_dependency("src/main.c")) + + def test_docstring_extraction(self): + """Docstring (comment above entity) should be extracted.""" + for entity in self.file.entities.values(): + if _entity_name(self.analyzer, entity) == "add": + doc = self.analyzer.get_entity_docstring(entity.node) + self.assertIsNotNone(doc) + self.assertIn("Adds two integers", doc) + return + self.fail("Function 'add' not found") + + def test_no_docstring(self): + """Entities without a preceding comment should return None.""" + for entity in self.file.entities.values(): + if _entity_name(self.analyzer, entity) == "main": + doc = self.analyzer.get_entity_docstring(entity.node) + self.assertIsNone(doc) + return + self.fail("Function 'main' not found") + + def test_unknown_entity_label_raises(self): + """get_entity_label should raise for unknown node types.""" + # Use the tree root node which is 'translation_unit', not a known entity + with self.assertRaises(ValueError): + self.analyzer.get_entity_label(self.file.tree.root_node) + + def test_unknown_entity_name_raises(self): + """get_entity_name should raise for unknown node types.""" + with self.assertRaises(ValueError): + self.analyzer.get_entity_name(self.file.tree.root_node) + + def test_resolve_path(self): + """resolve_path should return the file path unchanged.""" + self.assertEqual( + self.analyzer.resolve_path("/foo/bar.c", Path("/root")), + "/foo/bar.c", + ) + + def test_include_extraction_empty(self): + """A file with no #include directives 
should return an empty list.""" + source = b"int main() { return 0; }" + tree = self.analyzer.parser.parse(source) + self.assertEqual(self.analyzer.get_include_paths(tree), []) + + def test_resolve_symbol_unknown_key_raises(self): + """resolve_symbol should raise ValueError for unknown keys.""" + with self.assertRaises(ValueError): + self.analyzer.resolve_symbol({}, None, Path("x.c"), Path("."), "unknown_key", None) + + +if __name__ == "__main__": + unittest.main()