diff --git a/crates/ruvector-graph/src/graph.rs b/crates/ruvector-graph/src/graph.rs index 53e722aa4..ab1dd2f75 100644 --- a/crates/ruvector-graph/src/graph.rs +++ b/crates/ruvector-graph/src/graph.rs @@ -167,10 +167,11 @@ impl GraphDB { // Edge operations - /// Create an edge + /// Create an edge. + /// + /// Disk-first: persists the edge before registering it in memory and + /// indexes so a storage failure leaves the graph state untouched. pub fn create_edge(&self, edge: Edge) -> Result<String> { - let id = edge.id.clone(); - // Verify nodes exist if !self.nodes.contains_key(&edge.from) || !self.nodes.contains_key(&edge.to) { return Err(crate::error::GraphError::NodeNotFound( @@ -178,19 +179,18 @@ impl GraphDB { )); } - // Update indexes - self.edge_type_index.add_edge(&edge); - self.adjacency_index.add_edge(&edge); - - // Insert into memory - self.edges.insert(id.clone(), edge.clone()); - - // Persist to storage if available + // Persist first — on failure, memory/indexes stay coherent. #[cfg(feature = "storage")] if let Some(storage) = &self.storage { storage.insert_edge(&edge)?; } + // Persist succeeded — register in indexes and memory. + let id = edge.id.clone(); + self.edge_type_index.add_edge(&edge); + self.adjacency_index.add_edge(&edge); + self.edges.insert(id.clone(), edge); + Ok(id) } @@ -199,49 +199,58 @@ impl GraphDB { self.edges.get(id.as_ref()).map(|entry| entry.clone()) } - /// Delete an edge + /// Delete an edge. + /// + /// Disk-first: persists the deletion before mutating in-memory state so a + /// storage failure leaves memory and indexes untouched. 
pub fn delete_edge(&self, id: impl AsRef<str>) -> Result<bool> { - if let Some((_, edge)) = self.edges.remove(id.as_ref()) { - // Update indexes - self.edge_type_index.remove_edge(&edge); - self.adjacency_index.remove_edge(&edge); + let key = id.as_ref(); - // Delete from storage if available - #[cfg(feature = "storage")] - if let Some(storage) = &self.storage { - storage.delete_edge(id.as_ref())?; - } + let Some(edge) = self.edges.get(key).map(|e| e.clone()) else { + return Ok(false); + }; - Ok(true) - } else { - Ok(false) + #[cfg(feature = "storage")] + if let Some(storage) = &self.storage { + storage.delete_edge(key)?; } - } - /// Delete multiple edges (batch) - pub fn delete_edges_batch(&self, ids: &[impl AsRef<str>]) -> Result<usize> { - let mut deleted = 0; - let mut edges_to_update = Vec::with_capacity(ids.len()); + self.edge_type_index.remove_edge(&edge); + self.adjacency_index.remove_edge(&edge); + self.edges.remove(key); - for id in ids { - let key: &str = id.as_ref(); - if let Some((_, edge)) = self.edges.remove(key) { - edges_to_update.push(edge); - deleted += 1; - } - } + Ok(true) + } - for edge in &edges_to_update { - self.edge_type_index.remove_edge(edge); - self.adjacency_index.remove_edge(edge); - } + /// Delete multiple edges (batch). + /// + /// Disk-first: persists the deletion before mutating in-memory state so a + /// storage failure leaves memory and indexes untouched (caller sees an + /// `Err` and the graph is still coherent). + pub fn delete_edges_batch(&self, ids: &[impl AsRef<str>]) -> Result<usize> { + // Snapshot edges for later index updates without mutating memory. + let edges_to_remove = ids + .iter() + .filter_map(|id| self.edges.get(id.as_ref()).map(|e| e.clone())) + .collect::<Vec<_>>(); + // Persist first — on failure, memory/indexes stay coherent. #[cfg(feature = "storage")] if let Some(storage) = &self.storage { let str_ids = ids.iter().map(|id| id.as_ref()).collect::<Vec<_>>(); storage.delete_edges_batch(&str_ids)?; } + // Persist succeeded — apply in-memory removal. 
+ let mut deleted = 0; + for edge in &edges_to_remove { + if self.edges.remove(edge.id.as_str()).is_some() { + self.edge_type_index.remove_edge(edge); + self.adjacency_index.remove_edge(edge); + deleted += 1; + } + } + Ok(deleted) } diff --git a/crates/sona/src/lora.rs b/crates/sona/src/lora.rs index 87aac544e..11c7e801c 100644 --- a/crates/sona/src/lora.rs +++ b/crates/sona/src/lora.rs @@ -271,14 +271,33 @@ impl MicroLoRA { (&self.down_proj, &self.up_proj) } - /// Set LoRA weights from external source (disk load, other system) + /// Set LoRA weights from external source (disk load, another system, etc.) + /// + /// Implements a **Transaction Pattern** for safe weight loading: + /// 1. **Validation Phase**: All checks (dimension + numerical integrity) run first + /// 2. **Commitment Phase**: Internal state updates ONLY if all validations pass + /// + /// This prevents "Model Explosion" — loading corrupted weights cannot corrupt + /// the internal state because validation happens before any mutation. 
/// /// # Arguments - /// * `down_proj` - Down projection weights (hidden_dim * rank) - /// * `up_proj` - Up projection weights (rank * hidden_dim) + /// * `down_proj` - Down projection matrix, shape `[hidden_dim * rank]` + /// * `up_proj` - Up projection matrix, shape `[rank * hidden_dim]` /// /// # Errors - /// Returns Err if dimensions don't match current rank/hidden_dim + /// Returns `Err` with descriptive message if: + /// - Dimension mismatch: `down_proj.len() != hidden_dim * rank` + /// - Dimension mismatch: `up_proj.len() != rank * hidden_dim` + /// - Numerical instability: `down_proj` contains `NaN` or `Inf` + /// - Numerical instability: `up_proj` contains `NaN` or `Inf` + /// + /// # Example + /// ``` + /// let mut lora = MicroLoRA::new(64, 2); + /// let down = vec![0.1f32; 64 * 2]; + /// let up = vec![0.2f32; 2 * 64]; + /// assert!(lora.set_weights(down, up).is_ok()); + /// ``` pub fn set_weights(&mut self, down_proj: Vec<f32>, up_proj: Vec<f32>) -> Result<(), String> { let expected_down = self.hidden_dim * self.rank; if down_proj.len() != expected_down { @@ -298,6 +317,16 @@ impl MicroLoRA { )); } + // Prevent "Model Explosion": Ensure no NaN or Infinite values exist. + // This is crucial when loading weights from unverified external sources. 
+ if down_proj.iter().any(|&x| x.is_nan() || x.is_infinite()) { + return Err("Numerical error: down_proj contains NaN or Inf".to_string()); + } + + if up_proj.iter().any(|&x| x.is_nan() || x.is_infinite()) { + return Err("Numerical error: up_proj contains NaN or Inf".to_string()); + } + self.down_proj = down_proj; self.up_proj = up_proj; Ok(()) @@ -583,4 +612,66 @@ mod tests { assert!(result.is_err()); assert!(result.unwrap_err().contains("up_proj dimension mismatch")); } + + #[test] + fn test_set_weights_nan_in_down_proj() { + let mut lora = MicroLoRA::new(64, 2); + let mut down = vec![1.0f32; 64 * 2]; + down[10] = f32::NAN; + let up = vec![0.5f32; 2 * 64]; + + let result = lora.set_weights(down, up); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("NaN")); + } + + #[test] + fn test_set_weights_inf_in_down_proj() { + let mut lora = MicroLoRA::new(64, 2); + let mut down = vec![1.0f32; 64 * 2]; + down[5] = f32::INFINITY; + let up = vec![0.5f32; 2 * 64]; + + let result = lora.set_weights(down, up); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Inf")); + } + + #[test] + fn test_set_weights_nan_in_up_proj() { + let mut lora = MicroLoRA::new(64, 2); + let down = vec![1.0f32; 64 * 2]; + let mut up = vec![0.5f32; 2 * 64]; + up[20] = f32::NAN; + + let result = lora.set_weights(down, up); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("NaN")); + } + + #[test] + fn test_set_weights_inf_in_up_proj() { + let mut lora = MicroLoRA::new(64, 2); + let down = vec![1.0f32; 64 * 2]; + let mut up = vec![0.5f32; 2 * 64]; + up[30] = f32::NEG_INFINITY; + + let result = lora.set_weights(down, up); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Inf")); + } + + #[test] + fn test_set_weights_preserves_original_on_validation_failure() { + let mut lora = MicroLoRA::new(64, 2); + let original_down = lora.get_weights().0.clone(); + let original_up = lora.get_weights().1.clone(); + + let result = 
lora.set_weights(vec![1.0f32; 64 * 2], vec![0.5f32; 3 * 64]); + assert!(result.is_err()); + + let (down, up) = lora.get_weights(); + assert_eq!(down, &original_down); + assert_eq!(up, &original_up); + } }