Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 49 additions & 40 deletions crates/ruvector-graph/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,30 +167,30 @@ impl GraphDB {

// Edge operations

/// Create an edge.
///
/// Disk-first: persists the edge before registering it in memory and
/// indexes so a storage failure leaves the graph state untouched.
///
/// # Errors
/// - `GraphError::NodeNotFound` if either endpoint node is missing.
/// - Propagates the storage error if persisting the edge fails (in which
///   case no in-memory state has been touched).
pub fn create_edge(&self, edge: Edge) -> Result<EdgeId> {
    // Verify both endpoints exist before doing any work.
    if !self.nodes.contains_key(&edge.from) || !self.nodes.contains_key(&edge.to) {
        return Err(crate::error::GraphError::NodeNotFound(
            "Source or target node not found".to_string(),
        ));
    }

    // Persist first — on failure, memory/indexes stay coherent.
    #[cfg(feature = "storage")]
    if let Some(storage) = &self.storage {
        storage.insert_edge(&edge)?;
    }

    // Persist succeeded — register in indexes and memory.
    let id = edge.id.clone();
    self.edge_type_index.add_edge(&edge);
    self.adjacency_index.add_edge(&edge);
    self.edges.insert(id.clone(), edge);

    Ok(id)
}

Expand All @@ -199,49 +199,58 @@ impl GraphDB {
self.edges.get(id.as_ref()).map(|entry| entry.clone())
}

/// Delete an edge
/// Delete an edge.
///
/// Disk-first: persists the deletion before mutating in-memory state so a
/// storage failure leaves memory and indexes untouched.
pub fn delete_edge(&self, id: impl AsRef<str>) -> Result<bool> {
if let Some((_, edge)) = self.edges.remove(id.as_ref()) {
// Update indexes
self.edge_type_index.remove_edge(&edge);
self.adjacency_index.remove_edge(&edge);
let key = id.as_ref();

// Delete from storage if available
#[cfg(feature = "storage")]
if let Some(storage) = &self.storage {
storage.delete_edge(id.as_ref())?;
}
let Some(edge) = self.edges.get(key).map(|e| e.clone()) else {
return Ok(false);
};

Ok(true)
} else {
Ok(false)
#[cfg(feature = "storage")]
if let Some(storage) = &self.storage {
storage.delete_edge(key)?;
}
}

/// Delete multiple edges (batch)
pub fn delete_edges_batch(&self, ids: &[impl AsRef<str>]) -> Result<usize> {
let mut deleted = 0;
let mut edges_to_update = Vec::with_capacity(ids.len());
self.edge_type_index.remove_edge(&edge);
self.adjacency_index.remove_edge(&edge);
self.edges.remove(key);

for id in ids {
let key: &str = id.as_ref();
if let Some((_, edge)) = self.edges.remove(key) {
edges_to_update.push(edge);
deleted += 1;
}
}
Ok(true)
}

for edge in &edges_to_update {
self.edge_type_index.remove_edge(edge);
self.adjacency_index.remove_edge(edge);
}
/// Delete multiple edges (batch).
///
/// Disk-first: persists the deletion before mutating in-memory state so a
/// storage failure leaves memory and indexes untouched (caller sees an
/// `Err` and the graph is still coherent).
///
/// Returns the number of edges actually removed from memory. Ids that do
/// not resolve to an edge are skipped; duplicate ids are counted once
/// (the second removal attempt finds nothing).
///
/// # Errors
/// Propagates the storage error if the batch deletion fails to persist,
/// in which case no in-memory state has been mutated.
///
/// NOTE(review): snapshot-then-remove is not atomic — a concurrent writer
/// could insert/remove between the snapshot and the removal loop; assumes
/// callers serialize mutations or tolerate that window. TODO confirm.
pub fn delete_edges_batch(&self, ids: &[impl AsRef<str>]) -> Result<usize> {
    // Snapshot edges for later index updates without mutating memory.
    let edges_to_remove = ids
        .iter()
        .filter_map(|id| self.edges.get(id.as_ref()).map(|e| e.clone()))
        .collect::<Vec<_>>();

    // Persist first — on failure, memory/indexes stay coherent.
    #[cfg(feature = "storage")]
    if let Some(storage) = &self.storage {
        let str_ids = ids.iter().map(|id| id.as_ref()).collect::<Vec<_>>();
        storage.delete_edges_batch(&str_ids)?;
    }

    // Persist succeeded — apply in-memory removal.
    // The `is_some()` guard makes duplicate ids in the input harmless.
    let mut deleted = 0;
    for edge in &edges_to_remove {
        if self.edges.remove(edge.id.as_str()).is_some() {
            self.edge_type_index.remove_edge(edge);
            self.adjacency_index.remove_edge(edge);
            deleted += 1;
        }
    }

    Ok(deleted)
}

Expand Down
99 changes: 95 additions & 4 deletions crates/sona/src/lora.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,14 +271,33 @@ impl MicroLoRA {
(&self.down_proj, &self.up_proj)
}

/// Set LoRA weights from external source (disk load, another system, etc.)
///
/// Implements a **Transaction Pattern** for safe weight loading:
/// 1. **Validation Phase**: All checks (dimension + numerical integrity) run first
/// 2. **Commitment Phase**: Internal state updates ONLY if all validations pass
///
/// This prevents "Model Explosion" — loading corrupted weights cannot corrupt
/// the internal state because validation happens before any mutation.
///
/// # Arguments
/// * `down_proj` - Down projection matrix, shape `[hidden_dim * rank]`
/// * `up_proj` - Up projection matrix, shape `[rank * hidden_dim]`
///
/// # Errors
/// Returns `Err` with a descriptive message if:
/// - Dimension mismatch: `down_proj.len() != hidden_dim * rank`
/// - Dimension mismatch: `up_proj.len() != rank * hidden_dim`
/// - Numerical instability: `down_proj` contains `NaN` or `Inf`
/// - Numerical instability: `up_proj` contains `NaN` or `Inf`
///
/// # Example
/// ```
/// let mut lora = MicroLoRA::new(64, 2);
/// let down = vec![0.1f32; 64 * 2];
/// let up = vec![0.2f32; 2 * 64];
/// assert!(lora.set_weights(down, up).is_ok());
/// ```
pub fn set_weights(&mut self, down_proj: Vec<f32>, up_proj: Vec<f32>) -> Result<(), String> {
let expected_down = self.hidden_dim * self.rank;
if down_proj.len() != expected_down {
Expand All @@ -298,6 +317,16 @@ impl MicroLoRA {
));
}

// Prevent "Model Explosion": Ensure no NaN or Infinite values exist.
// This is crucial when loading weights from unverified external sources.
if down_proj.iter().any(|&x| x.is_nan() || x.is_infinite()) {
return Err("Numerical error: down_proj contains NaN or Inf".to_string());
}

if up_proj.iter().any(|&x| x.is_nan() || x.is_infinite()) {
return Err("Numerical error: up_proj contains NaN or Inf".to_string());
}

self.down_proj = down_proj;
self.up_proj = up_proj;
Ok(())
Expand Down Expand Up @@ -583,4 +612,66 @@ mod tests {
assert!(result.is_err());
assert!(result.unwrap_err().contains("up_proj dimension mismatch"));
}

#[test]
fn test_set_weights_nan_in_down_proj() {
    let mut lora = MicroLoRA::new(64, 2);

    // Corrupt a single entry of the down projection with NaN.
    let mut down_weights = vec![1.0f32; 64 * 2];
    down_weights[10] = f32::NAN;

    let err = lora
        .set_weights(down_weights, vec![0.5f32; 2 * 64])
        .expect_err("NaN in down_proj must be rejected");
    assert!(err.contains("NaN"));
}

#[test]
fn test_set_weights_inf_in_down_proj() {
    let mut lora = MicroLoRA::new(64, 2);

    // Corrupt a single entry of the down projection with +Inf.
    let mut down_weights = vec![1.0f32; 64 * 2];
    down_weights[5] = f32::INFINITY;

    let err = lora
        .set_weights(down_weights, vec![0.5f32; 2 * 64])
        .expect_err("Inf in down_proj must be rejected");
    assert!(err.contains("Inf"));
}

#[test]
fn test_set_weights_nan_in_up_proj() {
    let mut lora = MicroLoRA::new(64, 2);

    // Corrupt a single entry of the up projection with NaN.
    let mut up_weights = vec![0.5f32; 2 * 64];
    up_weights[20] = f32::NAN;

    let err = lora
        .set_weights(vec![1.0f32; 64 * 2], up_weights)
        .expect_err("NaN in up_proj must be rejected");
    assert!(err.contains("NaN"));
}

#[test]
fn test_set_weights_inf_in_up_proj() {
    let mut lora = MicroLoRA::new(64, 2);

    // Corrupt a single entry of the up projection with -Inf.
    let mut up_weights = vec![0.5f32; 2 * 64];
    up_weights[30] = f32::NEG_INFINITY;

    let err = lora
        .set_weights(vec![1.0f32; 64 * 2], up_weights)
        .expect_err("-Inf in up_proj must be rejected");
    assert!(err.contains("Inf"));
}

#[test]
fn test_set_weights_preserves_original_on_validation_failure() {
    let mut lora = MicroLoRA::new(64, 2);

    // Snapshot the initial weights before the failed load attempt.
    let (snap_down, snap_up) = {
        let (d, u) = lora.get_weights();
        (d.clone(), u.clone())
    };

    // up_proj has the wrong length (3 * 64 instead of 2 * 64) — must fail.
    assert!(lora
        .set_weights(vec![1.0f32; 64 * 2], vec![0.5f32; 3 * 64])
        .is_err());

    // Validation failure must leave the internal state untouched.
    let (down_after, up_after) = lora.get_weights();
    assert_eq!(down_after, &snap_down);
    assert_eq!(up_after, &snap_up);
}
}
Loading