Skip to content

Commit 418a640

Browse files
author
TheDevConnor
committed
Made it so that you can call the struct insides its scope as wellq
1 parent 123fb7d commit 418a640

File tree

6 files changed

+357
-296
lines changed

6 files changed

+357
-296
lines changed

src/llvm/struct.c

Lines changed: 102 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,22 @@ LLVMValueRef codegen_stmt_struct(CodeGenContext *ctx, AstNode *node) {
4949
size_t data_field_count = 0;
5050
for (size_t i = 0; i < public_count; i++) {
5151
AstNode *member = node->stmt.struct_decl.public_members[i];
52-
if (member->type == AST_STMT_FIELD_DECL && !member->stmt.field_decl.function) {
52+
if (member->type == AST_STMT_FIELD_DECL &&
53+
!member->stmt.field_decl.function) {
5354
data_field_count++;
5455
}
5556
}
5657
for (size_t i = 0; i < private_count; i++) {
5758
AstNode *member = node->stmt.struct_decl.private_members[i];
58-
if (member->type == AST_STMT_FIELD_DECL && !member->stmt.field_decl.function) {
59+
if (member->type == AST_STMT_FIELD_DECL &&
60+
!member->stmt.field_decl.function) {
5961
data_field_count++;
6062
}
6163
}
6264

6365
if (data_field_count == 0) {
64-
fprintf(stderr, "Error: Struct %s must have at least one data field\n", struct_name);
66+
fprintf(stderr, "Error: Struct %s must have at least one data field\n",
67+
struct_name);
6568
return NULL;
6669
}
6770

@@ -89,14 +92,24 @@ LLVMValueRef codegen_stmt_struct(CodeGenContext *ctx, AstNode *node) {
8992
struct_info->field_is_public = (bool *)arena_alloc(
9093
ctx->arena, sizeof(bool) * data_field_count, alignof(bool));
9194

95+
// CRITICAL FIX: Create an OPAQUE struct type FIRST (forward declaration)
96+
// This allows self-referential structs like: struct Node { next: *Node; }
97+
struct_info->llvm_type = LLVMStructCreateNamed(ctx->context, struct_name);
98+
99+
// Add to context IMMEDIATELY so it can be found during field type resolution
100+
add_struct_type(ctx, struct_info);
101+
add_symbol(ctx, struct_name, NULL, struct_info->llvm_type, false);
102+
92103
// Process public data fields
93104
size_t field_index = 0;
94105
for (size_t i = 0; i < public_count; i++) {
95106
AstNode *member = node->stmt.struct_decl.public_members[i];
96-
if (member->type != AST_STMT_FIELD_DECL) continue;
97-
107+
if (member->type != AST_STMT_FIELD_DECL)
108+
continue;
109+
98110
// Skip methods for now, we'll process them after the struct type is created
99-
if (member->stmt.field_decl.function) continue;
111+
if (member->stmt.field_decl.function)
112+
continue;
100113

101114
const char *field_name = member->stmt.field_decl.name;
102115

@@ -109,13 +122,17 @@ LLVMValueRef codegen_stmt_struct(CodeGenContext *ctx, AstNode *node) {
109122
}
110123
}
111124

112-
struct_info->field_names[field_index] = arena_strdup(ctx->arena, field_name);
113-
struct_info->field_types[field_index] = codegen_type(ctx, member->stmt.field_decl.type);
114-
struct_info->field_element_types[field_index] = extract_element_type_from_ast(ctx, member->stmt.field_decl.type);
125+
struct_info->field_names[field_index] =
126+
arena_strdup(ctx->arena, field_name);
127+
struct_info->field_types[field_index] =
128+
codegen_type(ctx, member->stmt.field_decl.type);
129+
struct_info->field_element_types[field_index] =
130+
extract_element_type_from_ast(ctx, member->stmt.field_decl.type);
115131
struct_info->field_is_public[field_index] = true;
116132

117133
if (!struct_info->field_types[field_index]) {
118-
fprintf(stderr, "Error: Failed to resolve type for field %s in struct %s\n",
134+
fprintf(stderr,
135+
"Error: Failed to resolve type for field %s in struct %s\n",
119136
field_name, struct_name);
120137
return NULL;
121138
}
@@ -125,9 +142,11 @@ LLVMValueRef codegen_stmt_struct(CodeGenContext *ctx, AstNode *node) {
125142
// Process private data fields
126143
for (size_t i = 0; i < private_count; i++) {
127144
AstNode *member = node->stmt.struct_decl.private_members[i];
128-
if (member->type != AST_STMT_FIELD_DECL) continue;
129-
130-
if (member->stmt.field_decl.function) continue;
145+
if (member->type != AST_STMT_FIELD_DECL)
146+
continue;
147+
148+
if (member->stmt.field_decl.function)
149+
continue;
131150

132151
const char *field_name = member->stmt.field_decl.name;
133152

@@ -139,40 +158,45 @@ LLVMValueRef codegen_stmt_struct(CodeGenContext *ctx, AstNode *node) {
139158
}
140159
}
141160

142-
struct_info->field_names[field_index] = arena_strdup(ctx->arena, field_name);
143-
struct_info->field_types[field_index] = codegen_type(ctx, member->stmt.field_decl.type);
144-
struct_info->field_element_types[field_index] = extract_element_type_from_ast(ctx, member->stmt.field_decl.type);
145-
struct_info->field_is_public[field_index] = member->stmt.field_decl.is_public;
161+
struct_info->field_names[field_index] =
162+
arena_strdup(ctx->arena, field_name);
163+
struct_info->field_types[field_index] =
164+
codegen_type(ctx, member->stmt.field_decl.type);
165+
struct_info->field_element_types[field_index] =
166+
extract_element_type_from_ast(ctx, member->stmt.field_decl.type);
167+
struct_info->field_is_public[field_index] =
168+
member->stmt.field_decl.is_public;
146169

147170
if (!struct_info->field_types[field_index]) {
148-
fprintf(stderr, "Error: Failed to resolve type for field %s in struct %s\n",
171+
fprintf(stderr,
172+
"Error: Failed to resolve type for field %s in struct %s\n",
149173
field_name, struct_name);
150174
return NULL;
151175
}
152176
field_index++;
153177
}
154178

155-
// Create LLVM struct type
156-
struct_info->llvm_type = LLVMStructTypeInContext(
157-
ctx->context, struct_info->field_types, data_field_count, false);
158-
159-
// Add to context BEFORE processing methods
160-
add_struct_type(ctx, struct_info);
161-
add_symbol(ctx, struct_name, NULL, struct_info->llvm_type, false);
179+
// CRITICAL: Set the struct body AFTER all field types are resolved
180+
// This completes the opaque struct declaration with its actual fields
181+
LLVMStructSetBody(struct_info->llvm_type, struct_info->field_types,
182+
data_field_count, false);
162183

163-
// NOW process methods with access to the struct type
184+
// NOW process methods with access to the complete struct type
164185
for (size_t i = 0; i < public_count; i++) {
165186
AstNode *member = node->stmt.struct_decl.public_members[i];
166-
if (member->type != AST_STMT_FIELD_DECL) continue;
167-
187+
if (member->type != AST_STMT_FIELD_DECL)
188+
continue;
189+
168190
// Only process methods
169-
if (!member->stmt.field_decl.function) continue;
191+
if (!member->stmt.field_decl.function)
192+
continue;
170193

171194
AstNode *func_node = member->stmt.field_decl.function;
172195
const char *method_name = member->stmt.field_decl.name;
173196

174197
// Generate the method with implicit 'self' parameter
175-
if (!codegen_struct_method(ctx, func_node, struct_info, method_name, true)) {
198+
if (!codegen_struct_method(ctx, func_node, struct_info, method_name,
199+
true)) {
176200
fprintf(stderr, "Error: Failed to generate method '%s' for struct '%s'\n",
177201
method_name, struct_name);
178202
return NULL;
@@ -182,15 +206,19 @@ LLVMValueRef codegen_stmt_struct(CodeGenContext *ctx, AstNode *node) {
182206
// Process private methods
183207
for (size_t i = 0; i < private_count; i++) {
184208
AstNode *member = node->stmt.struct_decl.private_members[i];
185-
if (member->type != AST_STMT_FIELD_DECL) continue;
186-
187-
if (!member->stmt.field_decl.function) continue;
209+
if (member->type != AST_STMT_FIELD_DECL)
210+
continue;
211+
212+
if (!member->stmt.field_decl.function)
213+
continue;
188214

189215
AstNode *func_node = member->stmt.field_decl.function;
190216
const char *method_name = member->stmt.field_decl.name;
191217

192-
if (!codegen_struct_method(ctx, func_node, struct_info, method_name, false)) {
193-
fprintf(stderr, "Error: Failed to generate private method '%s' for struct '%s'\n",
218+
if (!codegen_struct_method(ctx, func_node, struct_info, method_name,
219+
false)) {
220+
fprintf(stderr,
221+
"Error: Failed to generate private method '%s' for struct '%s'\n",
194222
method_name, struct_name);
195223
return NULL;
196224
}
@@ -199,11 +227,12 @@ LLVMValueRef codegen_stmt_struct(CodeGenContext *ctx, AstNode *node) {
199227
return NULL;
200228
}
201229

202-
LLVMValueRef codegen_struct_method(CodeGenContext *ctx, AstNode *func_node,
203-
StructInfo *struct_info, const char *method_name,
204-
bool is_public) {
230+
LLVMValueRef codegen_struct_method(CodeGenContext *ctx, AstNode *func_node,
231+
StructInfo *struct_info,
232+
const char *method_name, bool is_public) {
205233
if (!func_node || func_node->type != AST_STMT_FUNCTION) {
206-
fprintf(stderr, "Error: Invalid function node for method '%s'\n", method_name);
234+
fprintf(stderr, "Error: Invalid function node for method '%s'\n",
235+
method_name);
207236
return NULL;
208237
}
209238

@@ -214,29 +243,32 @@ LLVMValueRef codegen_struct_method(CodeGenContext *ctx, AstNode *func_node,
214243
char **original_param_names = func_node->stmt.func_decl.param_names;
215244

216245
// CRITICAL: Methods need an implicit 'self' parameter as the FIRST parameter
217-
// The typechecker injects 'self' when calling methods, so the method definition must match
246+
// The typechecker injects 'self' when calling methods, so the method
247+
// definition must match
218248
size_t param_count = original_param_count + 1; // +1 for 'self'
219-
249+
220250
// Allocate arrays for ALL parameters (including self)
221251
LLVMTypeRef *llvm_param_types = (LLVMTypeRef *)arena_alloc(
222252
ctx->arena, sizeof(LLVMTypeRef) * param_count, alignof(LLVMTypeRef));
223-
253+
224254
char **param_names = (char **)arena_alloc(
225255
ctx->arena, sizeof(char *) * param_count, alignof(char *));
226-
256+
227257
AstNode **param_type_nodes = (AstNode **)arena_alloc(
228258
ctx->arena, sizeof(AstNode *) * param_count, alignof(AstNode *));
229259

230260
// First parameter is 'self' - a pointer to the struct
231261
llvm_param_types[0] = LLVMPointerType(struct_info->llvm_type, 0);
232262
param_names[0] = "self";
233-
param_type_nodes[0] = NULL; // We'll handle this specially for element type extraction
263+
param_type_nodes[0] =
264+
NULL; // We'll handle this specially for element type extraction
234265

235266
// Copy the rest of the original parameters (shifted by 1)
236267
for (size_t i = 0; i < original_param_count; i++) {
237268
llvm_param_types[i + 1] = codegen_type(ctx, original_param_type_nodes[i]);
238269
if (!llvm_param_types[i + 1]) {
239-
fprintf(stderr, "Error: Failed to resolve parameter type %zu for method '%s'\n",
270+
fprintf(stderr,
271+
"Error: Failed to resolve parameter type %zu for method '%s'\n",
240272
i, method_name);
241273
return NULL;
242274
}
@@ -247,22 +279,25 @@ LLVMValueRef codegen_struct_method(CodeGenContext *ctx, AstNode *func_node,
247279
// Create function type
248280
LLVMTypeRef llvm_return_type = codegen_type(ctx, return_type_node);
249281
if (!llvm_return_type) {
250-
fprintf(stderr, "Error: Failed to resolve return type for method '%s'\n", method_name);
282+
fprintf(stderr, "Error: Failed to resolve return type for method '%s'\n",
283+
method_name);
251284
return NULL;
252285
}
253286

254-
LLVMTypeRef func_type = LLVMFunctionType(llvm_return_type, llvm_param_types,
255-
param_count, 0);
287+
LLVMTypeRef func_type =
288+
LLVMFunctionType(llvm_return_type, llvm_param_types, param_count, 0);
256289

257290
// Get the current LLVM module
258291
LLVMModuleRef current_llvm_module =
259292
ctx->current_module ? ctx->current_module->module : ctx->module;
260293

261294
// Create the function in the current module
262-
LLVMValueRef func = LLVMAddFunction(current_llvm_module, method_name, func_type);
263-
295+
LLVMValueRef func =
296+
LLVMAddFunction(current_llvm_module, method_name, func_type);
297+
264298
if (!func) {
265-
fprintf(stderr, "Error: Failed to create LLVM function for method '%s'\n", method_name);
299+
fprintf(stderr, "Error: Failed to create LLVM function for method '%s'\n",
300+
method_name);
266301
return NULL;
267302
}
268303

@@ -275,31 +310,34 @@ LLVMValueRef codegen_struct_method(CodeGenContext *ctx, AstNode *func_node,
275310

276311
// CRITICAL: Save the old function context before starting method generation
277312
LLVMValueRef old_function = ctx->current_function;
278-
313+
279314
// Set current function context
280315
ctx->current_function = func;
281316

282317
// Create entry basic block
283-
LLVMBasicBlockRef entry = LLVMAppendBasicBlockInContext(
284-
ctx->context, func, "entry");
318+
LLVMBasicBlockRef entry =
319+
LLVMAppendBasicBlockInContext(ctx->context, func, "entry");
285320
LLVMPositionBuilderAtEnd(ctx->builder, entry);
286321

287322
// Add all parameters to symbol table (including self at index 0)
288323
for (size_t i = 0; i < param_count; i++) {
289324
LLVMValueRef param = LLVMGetParam(func, i);
290325
const char *param_name = param_names[i];
291-
326+
292327
LLVMSetValueName2(param, param_name, strlen(param_name));
293-
328+
294329
// Allocate stack space and store parameter
295-
LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, llvm_param_types[i], param_name);
330+
LLVMValueRef alloca =
331+
LLVMBuildAlloca(ctx->builder, llvm_param_types[i], param_name);
296332
LLVMBuildStore(ctx->builder, param, alloca);
297-
298-
// Extract element type for pointer parameters (needed for self which is *Person)
299-
LLVMTypeRef element_type = extract_element_type_from_ast(ctx, param_type_nodes[i]);
300-
333+
334+
// Extract element type for pointer parameters (needed for self which is
335+
// *Person)
336+
LLVMTypeRef element_type =
337+
extract_element_type_from_ast(ctx, param_type_nodes[i]);
338+
301339
// Add to symbol table with element type information
302-
add_symbol_with_element_type(ctx, param_name, alloca, llvm_param_types[i],
340+
add_symbol_with_element_type(ctx, param_name, alloca, llvm_param_types[i],
303341
element_type, false);
304342
}
305343

@@ -318,7 +356,8 @@ LLVMValueRef codegen_struct_method(CodeGenContext *ctx, AstNode *func_node,
318356

319357
// Verify the function
320358
if (LLVMVerifyFunction(func, LLVMReturnStatusAction)) {
321-
fprintf(stderr, "Error: Function verification failed for method '%s'\n", method_name);
359+
fprintf(stderr, "Error: Function verification failed for method '%s'\n",
360+
method_name);
322361
LLVMDumpValue(func);
323362
// Restore context even on error
324363
ctx->current_function = old_function;

0 commit comments

Comments
 (0)