64 changes: 55 additions & 9 deletions main.c
@@ -1024,8 +1024,31 @@ static void print_mmu_cache_stats(vm_t *vm)
fprintf(stderr, "\n=== MMU Cache Statistics ===\n");
for (uint32_t i = 0; i < vm->n_hart; i++) {
hart_t *hart = vm->hart[i];
uint64_t fetch_total =
hart->cache_fetch.hits + hart->cache_fetch.misses;

/* Combine 2-entry tlb statistics */
uint64_t fetch_hits_tlb = 0, fetch_misses_tlb = 0;
fetch_hits_tlb =
hart->cache_fetch[0].tlb_hits + hart->cache_fetch[1].tlb_hits;
fetch_misses_tlb =
hart->cache_fetch[0].tlb_misses + hart->cache_fetch[1].tlb_misses;
uint64_t tlb_total = fetch_hits_tlb + fetch_misses_tlb;

/* Combine icache statistics */
uint64_t fetch_hits_icache = 0, fetch_misses_icache = 0;
fetch_hits_icache =
hart->cache_fetch[0].icache_hits + hart->cache_fetch[1].icache_hits;
fetch_misses_icache = hart->cache_fetch[0].icache_misses +
hart->cache_fetch[1].icache_misses;

/* Combine victim cache statistics */
uint64_t fetch_hits_vcache = 0, fetch_misses_vcache = 0;
fetch_hits_vcache =
hart->cache_fetch[0].vcache_hits + hart->cache_fetch[1].vcache_hits;
fetch_misses_vcache = hart->cache_fetch[0].vcache_misses +
hart->cache_fetch[1].vcache_misses;

uint64_t access_total =
hart->cache_fetch[0].total_fetch + hart->cache_fetch[1].total_fetch;

/* Combine 8-set × 2-way load cache statistics */
uint64_t load_hits = 0, load_misses = 0;
@@ -1047,14 +1070,37 @@ static void print_mmu_cache_stats(vm_t *vm)
}
uint64_t store_total = store_hits + store_misses;

fprintf(stderr, "\nHart %u:\n", i);
fprintf(stderr, " Fetch: %12llu hits, %12llu misses",
hart->cache_fetch.hits, hart->cache_fetch.misses);
if (fetch_total > 0)
fprintf(stderr, " (%.2f%% hit rate)",
100.0 * hart->cache_fetch.hits / fetch_total);
fprintf(stderr, "\n");

fprintf(stderr, "\n=== Introduction Cache Statistics ===\n");
fprintf(stderr, " Total access: %12llu\n", access_total);
if (access_total > 0) {
fprintf(stderr, " Icache hits: %12llu (%.2f%%)\n",
fetch_hits_icache,
(fetch_hits_icache * 100.0) / access_total);

fprintf(stderr, " Icache misses: %12llu (%.2f%%)\n",
fetch_misses_icache,
(fetch_misses_icache * 100.0) / access_total);
}
if (access_total > 0 && fetch_misses_icache > 0) {
fprintf(stderr,
" ├ Vcache hits: %8llu (%.2f%% of Icache misses)\n",
fetch_hits_vcache,
(fetch_hits_vcache * 100.0) / fetch_misses_icache);
fprintf(stderr,
" └ Vcache misses: %8llu (%.2f%% of Icache misses)\n",
fetch_misses_vcache,
(fetch_misses_vcache * 100.0) / fetch_misses_icache);
}
if (tlb_total > 0) {
fprintf(stderr, " ├ TLB hits: %4llu (%.2f%%)\n",
fetch_hits_tlb, (fetch_hits_tlb * 100.0) / (tlb_total));
fprintf(stderr, " └ TLB misses: %4llu (%.2f%%)\n",
fetch_misses_tlb, (fetch_misses_tlb * 100.0) / (tlb_total));
}
fprintf(stderr, "\n=== Data Cache Statistics ===\n");
fprintf(stderr, " Load: %12llu hits, %12llu misses (8x2)", load_hits,
load_misses);
if (load_total > 0)
112 changes: 100 additions & 12 deletions riscv.c
@@ -1,4 +1,5 @@
#include <stdio.h>
#include <string.h>

#include "common.h"
#include "device.h"
@@ -180,11 +181,17 @@ static inline uint32_t read_rs2(const hart_t *vm, uint32_t insn)
return vm->x_regs[decode_rs2(insn)];
}

static inline void icache_invalidate_all(hart_t *vm)
{
memset(&vm->icache, 0, sizeof(vm->icache));
}

/* virtual addressing */

void mmu_invalidate(hart_t *vm)
{
vm->cache_fetch.n_pages = 0xFFFFFFFF;
vm->cache_fetch[0].n_pages = 0xFFFFFFFF;
vm->cache_fetch[1].n_pages = 0xFFFFFFFF;
/* Invalidate all 8 sets × 2 ways for load cache */
for (int set = 0; set < 8; set++) {
for (int way = 0; way < 2; way++)
@@ -197,6 +204,7 @@ void mmu_invalidate(hart_t *vm)
vm->cache_store[set].ways[way].n_pages = 0xFFFFFFFF;
vm->cache_store[set].lru = 0; /* Reset LRU to way 0 */
}
icache_invalidate_all(vm);
}

/* Invalidate MMU caches for a specific virtual address range.
@@ -227,9 +235,11 @@ void mmu_invalidate_range(hart_t *vm, uint32_t start_addr, uint32_t size)
uint32_t end_vpn = (uint32_t) end_addr >> RV_PAGE_SHIFT;

/* Cache invalidation for fetch cache */
if (vm->cache_fetch.n_pages >= start_vpn &&
vm->cache_fetch.n_pages <= end_vpn)
vm->cache_fetch.n_pages = 0xFFFFFFFF;
for (int i = 0; i < 2; i++) {
cubic-dev-ai (bot) commented on Nov 11, 2025:

mmu_invalidate_range now needs to flush the new instruction cache as well; otherwise pages whose mappings change keep returning stale instructions from vm->icache.

yy214123 (author) replied:

I’m a bit confused about this.
The current implementation is as follows:

/* Cache invalidation for fetch cache */
for (int i = 0; i < 2; i++) {
    if (vm->cache_fetch[i].n_pages >= start_vpn &&
        vm->cache_fetch[i].n_pages <= end_vpn)
        vm->cache_fetch[i].n_pages = 0xFFFFFFFF;
}


@yy214123 You're right to be confused—the code snippet you showed only handles invalidating the page translation caches (cache_fetch), but the comment was pointing out a bigger gap: the actual instruction cache blocks (vm->icache.i_block and vm->icache.v_block) aren't being touched at all.

Right now, when you call mmu_invalidate_range, it only invalidates the 2-entry page cache entries that fall in your VPN range. But if the instruction cache already holds translated/cached instruction blocks from those pages (which it does after fetches), they'll keep serving stale data even after the page mappings change—until those specific cache lines get evicted naturally by future misses.

To fix it properly, you need to add logic in mmu_invalidate_range to scan the I-cache and victim cache for blocks whose tags correspond to the invalidated address range, and mark those valid = false. Something like:

/* Invalidate I-cache blocks in the address range */
uint32_t start_block_addr = (start_addr >> ICACHE_OFFSET_BITS) << ICACHE_OFFSET_BITS;
uint32_t end_block_addr = ((end_addr + ICACHE_BLOCK_SIZE - 1) >> ICACHE_OFFSET_BITS) << ICACHE_OFFSET_BITS;

for (uint32_t block_idx = 0; block_idx < ICACHE_BLOCKS; block_idx++) {
    uint32_t block_base = (block_idx << ICACHE_OFFSET_BITS) | (vm->icache.i_block[block_idx].tag << (ICACHE_OFFSET_BITS + ICACHE_INDEX_BITS));
    if (block_base >= start_block_addr && block_base <= end_block_addr && vm->icache.i_block[block_idx].valid) {
        vm->icache.i_block[block_idx].valid = false;
    }
}

/* Similar loop for victim cache - check if v_block[i].tag matches any address in range */
for (int i = 0; i < VCACHE_BLOCKS; i++) {
    if (vm->icache.v_block[i].valid) {
        uint32_t v_block_addr = vm->icache.v_block[i].tag << ICACHE_OFFSET_BITS;
        if (v_block_addr >= start_block_addr && v_block_addr <= end_block_addr) {
            vm->icache.v_block[i].valid = false;
        }
    }
}

This is a bit more involved than the full icache_invalidate_all (which you already call in mmu_invalidate), but it's necessary for range-based invalidation to work correctly without flushing the entire cache unnecessarily. The tag math might need tweaking based on how you're storing virtual vs. physical addresses in the blocks—double-check that the reconstructed address comparison works as expected.

if (vm->cache_fetch[i].n_pages >= start_vpn &&
vm->cache_fetch[i].n_pages <= end_vpn)
vm->cache_fetch[i].n_pages = 0xFFFFFFFF;
}

/* Invalidate load cache: 8 sets × 2 ways */
for (int set = 0; set < 8; set++) {
@@ -361,10 +371,62 @@ static void mmu_fence(hart_t *vm, uint32_t insn UNUSED)

static void mmu_fetch(hart_t *vm, uint32_t addr, uint32_t *value)
{
uint32_t idx = (addr >> ICACHE_OFFSET_BITS) & ICACHE_INDEX_MASK;
uint32_t tag = addr >> (ICACHE_OFFSET_BITS + ICACHE_INDEX_BITS);
icache_block_t *blk = &vm->icache.i_block[idx];
uint32_t vpn = addr >> RV_PAGE_SHIFT;
if (unlikely(vpn != vm->cache_fetch.n_pages)) {
uint32_t index = __builtin_parity(vpn) & 0x1;
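/* The parity hash above maps the VPN to one of the two fetch-cache entries,
 * so two pages whose VPNs differ in bit parity can stay cached side by side
 * instead of evicting each other from a single slot. */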

#ifdef MMU_CACHE_STATS
vm->cache_fetch[index].total_fetch++;
#endif

/* icache lookup */
if (likely(blk->valid && blk->tag == tag)) {
#ifdef MMU_CACHE_STATS
vm->cache_fetch[index].icache_hits++;
#endif
uint32_t ofs = addr & ICACHE_BLOCK_MASK;
*value = *(const uint32_t *) (blk->base + ofs);
return;
}

/* icache miss, try victim cache */
#ifdef MMU_CACHE_STATS
vm->cache_fetch[index].icache_misses++;
#endif

uint32_t vcache_key = addr >> ICACHE_OFFSET_BITS;
Collaborator commented:

Fragile tag calculation:

  • Works by accident, obscures intent
  • Should explicitly store block_id instead of reconstructing
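
One way to act on this suggestion (a sketch only, not code from the PR; the block_id field and the lookup below are hypothetical) is to key both the I-cache and the victim cache on the full block identifier, addr >> ICACHE_OFFSET_BITS, so nothing ever has to be reconstructed from a narrower tag:

typedef struct {
    uint32_t block_id;   /* addr >> ICACHE_OFFSET_BITS: tag and index combined */
    const uint8_t *base;
    bool valid;
} icache_block_t;

/* lookup sketch */
uint32_t block_id = addr >> ICACHE_OFFSET_BITS;
uint32_t idx = block_id & ICACHE_INDEX_MASK;
icache_block_t *blk = &vm->icache.i_block[idx];
if (blk->valid && blk->block_id == block_id) {
    /* hit: the victim cache stores the same block_id, so the swap path is a
     * plain structure copy with no (tag << ICACHE_INDEX_BITS) | idx re-encoding */
}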

for (int i = 0; i < VCACHE_BLOCKS; i++) {
victim_cache_block_t *vblk = &vm->icache.v_block[i];

if (vblk->valid && vblk->tag == vcache_key) {
/* victim cache hit, swap blocks */
#ifdef MMU_CACHE_STATS
vm->cache_fetch.misses++;
vm->cache_fetch[index].vcache_hits++;
#endif
icache_block_t tmp = *blk;
*blk = *vblk;
*vblk = tmp;
blk->tag = tag;
Collaborator commented:

This code looks suspicious to me.

When you move the evicted I-cache block (tmp) back into the victim cache, you are setting the vblk->tag to tmp.tag, which is the 16-bit I-cache tag.

Won't this corrupt the victim cache entry? The VC search logic requires a 24-bit tag ([ICache Tag | ICache Index]) to function. Because you're only storing the 16-bit tag, this VCache entry will never be hit again.

yy214123 (author) replied:

Won't this corrupt the victim cache entry? The VC search logic requires a 24-bit tag ([ICache Tag | ICache Index]) to function. Because you're only storing the 16-bit tag, this VCache entry will never be hit again.

Thank you for pointing that out. I've added the following expression to ensure correctness:

+   vblk->tag = (tmp.tag << ICACHE_INDEX_BITS) | idx;

vblk->tag = (tmp.tag << ICACHE_INDEX_BITS) | idx;
vm->icache.v_used[i] = vm->instret;

uint32_t ofs = addr & ICACHE_BLOCK_MASK;
*value = *(const uint32_t *) (blk->base + ofs);
return;
}
}

#ifdef MMU_CACHE_STATS
vm->cache_fetch[index].vcache_misses++;
#endif

/* TLB lookup */
if (unlikely(vpn != vm->cache_fetch[index].n_pages)) {
/*TLB miss: need to translate VA to PA*/
#ifdef MMU_CACHE_STATS
vm->cache_fetch[index].tlb_misses++;
#endif
mmu_translate(vm, &addr, (1 << 3), (1 << 6), false, RV_EXC_FETCH_FAULT,
RV_EXC_FETCH_PFAULT);
@@ -374,15 +436,41 @@ static void mmu_fetch(hart_t *vm, uint32_t addr, uint32_t *value)
vm->mem_fetch(vm, addr >> RV_PAGE_SHIFT, &page_addr);
if (vm->error)
return;
vm->cache_fetch.n_pages = vpn;
vm->cache_fetch.page_addr = page_addr;
vm->cache_fetch[index].n_pages = vpn;
vm->cache_fetch[index].page_addr = page_addr;
}
#ifdef MMU_CACHE_STATS
/*TLB hit*/
else {
vm->cache_fetch.hits++;
}
#ifdef MMU_CACHE_STATS
vm->cache_fetch[index].tlb_hits++;
#endif
*value = vm->cache_fetch.page_addr[(addr >> 2) & MASK(RV_PAGE_SHIFT - 2)];
}

*value =
vm->cache_fetch[index].page_addr[(addr >> 2) & MASK(RV_PAGE_SHIFT - 2)];

/* Move the current icache block into the victim cache before replacement */
if (blk->valid) {
uint32_t lru_min = vm->icache.v_used[0];
int lru_index = 0;
for (int i = 1; i < VCACHE_BLOCKS; i++) {
if (vm->icache.v_used[i] < lru_min) {
lru_min = vm->icache.v_used[i];
lru_index = i;
}
}
Comment on lines +454 to +461
Collaborator commented:

O(N) Critical Path: Linear LRU scan on every eviction is unacceptable
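
A constant-time alternative (a sketch under assumptions, not the PR's design: v_next is a hypothetical field and round-robin replacement only approximates LRU) rotates a single victim pointer instead of scanning v_used[] on every eviction:

/* riscv.h: replace v_used[] with one rotating cursor inside icache_t */
uint32_t v_next;

/* riscv.c: O(1) eviction into the victim cache */
if (blk->valid) {
    victim_cache_block_t *vblk = &vm->icache.v_block[vm->icache.v_next];
    *vblk = *blk;
    vblk->tag = (blk->tag << ICACHE_INDEX_BITS) | idx;
    vblk->valid = true;
    /* VCACHE_BLOCKS is a power of two, so masking wraps the cursor */
    vm->icache.v_next = (vm->icache.v_next + 1) & (VCACHE_BLOCKS - 1);
}

Whether the hit-rate loss relative to true LRU matters here would have to be measured with MMU_CACHE_STATS.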

victim_cache_block_t *vblk = &vm->icache.v_block[lru_index];
*vblk = *blk;
vblk->tag = (blk->tag << ICACHE_INDEX_BITS) | idx;
vblk->valid = true;
vm->icache.v_used[lru_index] = vm->instret;
}

/* fill into the icache */
uint32_t block_off = (addr & RV_PAGE_MASK) & ~ICACHE_BLOCK_MASK;
blk->base = (const uint8_t *) vm->cache_fetch[index].page_addr + block_off;
Collaborator commented:

Pointer aliasing time bomb:

  • I-cache stores blk->base pointing to physical memory
  • TLB stores page_addr pointing to same memory
  • These can become desynchronized on page remapping → potential use-after-free

blk->tag = tag;
blk->valid = true;
}

static void mmu_load(hart_t *vm,
63 changes: 60 additions & 3 deletions riscv.h
@@ -36,8 +36,13 @@ typedef struct {
uint32_t n_pages;
uint32_t *page_addr;
#ifdef MMU_CACHE_STATS
uint64_t hits;
uint64_t misses;
uint64_t total_fetch;
uint64_t tlb_hits;
uint64_t tlb_misses;
uint64_t icache_hits;
uint64_t icache_misses;
uint64_t vcache_hits;
uint64_t vcache_misses;
#endif
} mmu_fetch_cache_t;

@@ -75,7 +80,58 @@ typedef struct {
typedef struct __hart_internal hart_t;
typedef struct __vm_internel vm_t;

/* ICACHE_BLOCKS_SIZE: Size of one instruction-cache block (line).
* ICACHE_BLOCKS: Number of blocks (lines) in the instruction cache.
*
* The cache address is decomposed into [ tag | index | offset ] fields:
* - block-offset bits = log2(ICACHE_BLOCKS_SIZE)
* - index bits = log2(ICACHE_BLOCKS)
*/
#define ICACHE_BLOCKS_SIZE 256
#define ICACHE_BLOCKS 256
#define ICACHE_OFFSET_BITS 8
#define ICACHE_INDEX_BITS 8

/* VCACHE_BLOCK_SIZE: Size of one victim-cache block (line).
* VCACHE_BLOCKS: Number of blocks (lines) in the victim cache.
*
* The victim cache is implemented as a small, fully associative cache.
* It is designed to serve as a temporary buffer for instruction cache blocks
* that were recently evicted from the instruction cache.
*
* Upon an instruction cache miss, the system first checks the victim cache
* for the corresponding data. If the data is found (a victim cache hit),
* the instruction cache block and the victim cache block are swapped.
* Conversely, when the instruction cache is being filled with new data,
* the evicted old data from the instruction cache block is simultaneously
* placed into the victim cache.
*/
#define VCACHE_BLOCK_SIZE ICACHE_BLOCKS_SIZE
#define VCACHE_BLOCKS 16

/* For power-of-two sizes, (size - 1) sets all low bits to 1,
* allowing fast extraction of the offset and index fields from an address.
*/
#define ICACHE_INDEX_MASK (ICACHE_BLOCKS - 1)
#define ICACHE_BLOCK_MASK (ICACHE_BLOCKS_SIZE - 1)
#define RV_PAGE_MASK (RV_PAGE_SIZE - 1)
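
/* Worked example (illustration only): with ICACHE_OFFSET_BITS = 8 and
 * ICACHE_INDEX_BITS = 8, a fetch from the hypothetical address 0x80201234
 * decomposes as
 *   offset = 0x80201234 & ICACHE_BLOCK_MASK                         -> 0x34
 *   index  = (0x80201234 >> ICACHE_OFFSET_BITS) & ICACHE_INDEX_MASK -> 0x12
 *   tag    = 0x80201234 >> (ICACHE_OFFSET_BITS + ICACHE_INDEX_BITS) -> 0x8020
 * while the victim-cache key keeps tag and index together:
 *   0x80201234 >> ICACHE_OFFSET_BITS                                -> 0x802012
 */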

typedef struct {
uint32_t tag;
const uint8_t *base;
bool valid;
} icache_block_t;

typedef icache_block_t victim_cache_block_t;

typedef struct {
icache_block_t i_block[ICACHE_BLOCKS];
victim_cache_block_t v_block[VCACHE_BLOCKS];
uint32_t v_used[VCACHE_BLOCKS];
} icache_t;

struct __hart_internal {
icache_t icache;
uint32_t x_regs[32];

/* LR reservation virtual address. last bit is 1 if valid */
Expand Down Expand Up @@ -106,7 +162,8 @@ struct __hart_internal {
*/
uint32_t exc_cause, exc_val;

mmu_fetch_cache_t cache_fetch;
/* 2-entry direct-mapped with hash-based indexing */
mmu_fetch_cache_t cache_fetch[2];
/* 8-set × 2-way set-associative cache with 3-bit parity hash indexing */
mmu_cache_set_t cache_load[8];
/* 8-set × 2-way set-associative cache for store operations */