152 changes: 87 additions & 65 deletions cmov/src/portable.rs
@@ -7,7 +7,9 @@
 //! optimizer potentially inserting branches.

 use crate::{Cmov, CmovEq, Condition};
+use core::ops::{BitAnd, BitOr, Not};
 
+// Uses `Cmov` impl for `u32`
 impl Cmov for u16 {
     #[inline]
     fn cmovnz(&mut self, value: &u16, condition: Condition) {
@@ -24,6 +26,7 @@ impl Cmov for u16 {
     }
 }
 
+// Uses `CmovEq` impl for `u32`
 impl CmovEq for u16 {
     #[inline]
     fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
@@ -39,59 +42,103 @@ impl CmovEq for u16 {
 impl Cmov for u32 {
     #[inline]
     fn cmovnz(&mut self, value: &Self, condition: Condition) {
-        let mask = masknz32(condition);
-        *self = (*self & !mask) | (*value & mask);
+        *self = masksel(*self, *value, masknz32(condition.into()));
     }
 
     #[inline]
     fn cmovz(&mut self, value: &Self, condition: Condition) {
-        let mask = masknz32(condition);
-        *self = (*self & mask) | (*value & !mask);
+        *self = masksel(*self, *value, !masknz32(condition.into()));
    }
 }
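For reference, both methods now share one select pattern: `masksel` keeps `self` where the mask bits are clear and takes `value` where they are set. A standalone sketch of that identity for `u32` (names here are illustrative, not part of this diff):

```rust
// Standalone sketch of the mask-select identity used by `cmovnz`/`cmovz`.
// With mask == 0 the result is `a`; with mask == u32::MAX it is `b`.
fn select_u32(a: u32, b: u32, mask: u32) -> u32 {
    (a & !mask) | (b & mask)
}

fn main() {
    let (a, b) = (1u32, 5u32);
    assert_eq!(select_u32(a, b, 0), a);        // condition was zero
    assert_eq!(select_u32(a, b, u32::MAX), b); // condition was non-zero
}
```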

 impl CmovEq for u32 {
     #[inline]
     fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
-        let ne = testne32(*self, *rhs);
-        output.cmovnz(&input, ne);
+        output.cmovnz(&input, testne32(*self, *rhs));
     }
 
     #[inline]
     fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
-        let eq = testeq32(*self, *rhs);
-        output.cmovnz(&input, eq);
+        output.cmovnz(&input, testeq32(*self, *rhs));
     }
 }
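The `output`/`input` pair gives `CmovEq` a chaining form: `output` is overwritten with `input` only when the comparison fires, so a caller can fold a multi-limb equality check into a single condition byte. A sketch of that usage, assuming the `cmov` crate as a dependency (`Condition` is a `u8` alias):

```rust
use cmov::CmovEq; // assumed dependency for this sketch

/// Constant-time slice equality: `ne` stays 0 iff every limb pair is equal.
fn ct_eq(a: &[u32], b: &[u32]) -> bool {
    debug_assert_eq!(a.len(), b.len());
    let mut ne = 0u8;
    for (x, y) in a.iter().zip(b.iter()) {
        x.cmovne(y, 1, &mut ne); // sets `ne` to 1 only when x != y
    }
    ne == 0
}
```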

 impl Cmov for u64 {
     #[inline]
     fn cmovnz(&mut self, value: &Self, condition: Condition) {
-        let mask = masknz64(condition);
-        *self = (*self & !mask) | (*value & mask);
+        *self = masksel(*self, *value, masknz64(condition.into()));
     }
 
     #[inline]
     fn cmovz(&mut self, value: &Self, condition: Condition) {
-        let mask = masknz64(condition);
-        *self = (*self & mask) | (*value & !mask);
+        *self = masksel(*self, *value, !masknz64(condition.into()));
     }
 }

 impl CmovEq for u64 {
     #[inline]
     fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
-        let ne = testne64(*self, *rhs);
-        output.cmovnz(&input, ne);
+        output.cmovnz(&input, testne64(*self, *rhs));
     }
 
     #[inline]
     fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
-        let eq = testeq64(*self, *rhs);
-        output.cmovnz(&input, eq);
+        output.cmovnz(&input, testeq64(*self, *rhs));
     }
 }

+/// Return a [`u32::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
+#[cfg(not(target_arch = "arm"))]
+fn masknz32(condition: u32) -> u32 {
+    testnz32(condition).wrapping_neg()
+}
+
+/// Return a [`u64::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
+#[cfg(not(target_arch = "arm"))]
+fn masknz64(condition: u64) -> u64 {
+    testnz64(condition).wrapping_neg()
+}
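The construction composes two steps: `testnz*` collapses any non-zero condition to exactly `1`, and `wrapping_neg` turns that single bit into a full mask, since `-1` is all-ones in two's complement. A quick standalone check:

```rust
fn main() {
    // testnz32 collapses a non-zero condition to 1; negation widens it.
    assert_eq!(1u32.wrapping_neg(), u32::MAX); // condition non-zero -> all-ones
    assert_eq!(0u32.wrapping_neg(), 0);        // condition zero     -> all-zeros
}
```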

+/// Optimized mask generation for ARM32 targets.
+///
+/// This is written in assembly both for performance and because we've had problematic code
+/// generation in this routine in the past which led to the insertion of a branch; writing it
+/// in assembly should guarantee that won't happen again (CVE-2026-23519).
+#[cfg(target_arch = "arm")]
+fn masknz32(condition: u32) -> u32 {
+    let mut mask = condition;
+    unsafe {
+        core::arch::asm!(
+            "rsbs {0}, {0}, #0",  // Reverse subtract: 0 - condition, setting flags
+            "sbcs {0}, {0}, {0}", // Subtract with carry: spreads the borrow into a full mask
+            inout(reg) mask,
+            options(nostack, nomem),
+        );
+    }
+    mask
+}
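How the two instructions produce the mask: `rsbs` computes `0 - condition` and sets the carry flag only when no borrow occurs, i.e. only when `condition == 0`; `sbcs r, r, r` then computes `r - r - (1 - carry)`, which is `0` when carry is set and `0xFFFF_FFFF` when it is clear. A hypothetical portable model of that flag arithmetic (illustration only; the real code uses assembly precisely so the compiler can't turn this into a branch):

```rust
fn masknz32_model(condition: u32) -> u32 {
    // rsbs {0}, {0}, #0  -> carry = 1 iff no borrow, i.e. iff condition == 0
    let carry = (condition == 0) as u32;
    // sbcs {0}, {0}, {0} -> r - r - (1 - carry) = carry - 1 (mod 2^32)
    carry.wrapping_sub(1)
}

fn main() {
    assert_eq!(masknz32_model(0), 0);
    assert_eq!(masknz32_model(7), u32::MAX);
}
```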

+/// 64-bit wrapper for targets that implement 32-bit mask generation in assembly.
+#[cfg(target_arch = "arm")]
+fn masknz64(condition: u64) -> u64 {
+    let lo = masknz32((condition & 0xFFFF_FFFF) as u32);
+    let hi = masknz32((condition >> 32) as u32);
+    let mask = (lo | hi) as u64;
+    mask | mask << 32
+}
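Each 32-bit half is masked independently, and because each sub-mask is either `0` or all-ones, `lo | hi` is all-ones exactly when either half of the input is non-zero; the final shift-or broadcasts that into both halves. A standalone sanity check, with an inline stand-in for `masknz32`:

```rust
fn main() {
    let condition: u64 = 1 << 40; // non-zero only in the high half
    // Stand-ins for masknz32 on each half:
    let lo = (((condition & 0xFFFF_FFFF) != 0) as u32).wrapping_neg();
    let hi = (((condition >> 32) != 0) as u32).wrapping_neg();
    let mask = (lo | hi) as u64;             // all-ones iff either half non-zero
    assert_eq!(mask | mask << 32, u64::MAX); // broadcast to the full width
}
```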

+/// Given a supplied mask of `0` or all 1-bits (i.e. `u*::MAX`), select `a` if the mask is all-zeros
+/// and `b` if the mask is all-ones.
+///
+/// This function shouldn't be used with a mask that isn't `0` or `u*::MAX`.
+#[inline]
+fn masksel<T>(a: T, b: T, mask: T) -> T
+where
+    T: BitAnd<Output = T> + BitOr<Output = T> + Copy + Not<Output = T>,
+{
+    (a & !mask) | (b & mask)
+}
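The warning in the doc comment is worth spelling out: a mask strictly between `0` and all-ones does not select either input, it blends bits from both. A standalone sketch:

```rust
fn masksel(a: u8, b: u8, mask: u8) -> u8 {
    (a & !mask) | (b & mask)
}

fn main() {
    // A partial mask blends the inputs bit-by-bit instead of selecting one:
    assert_eq!(masksel(0xF0, 0x0F, 0x33), 0xC3);
}
```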

 /// Returns `1` if `x` is equal to `y`, otherwise returns `0` (32-bit version)
 fn testeq32(x: u32, y: u32) -> Condition {
     testne32(x, y) ^ 1
@@ -120,46 +167,37 @@ fn testnz32(mut x: u32) -> u32 {

 /// Returns `0` if `x` is `0`, otherwise returns `1` (64-bit version)
 fn testnz64(mut x: u64) -> u64 {
-    x |= x.wrapping_neg(); // MSB now set if non-zero
+    x |= x.wrapping_neg(); // MSB now set if non-zero (or unset if zero)
     core::hint::black_box(x >> (u64::BITS - 1)) // Extract MSB
 }
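The trick relies on two's complement: if `x` is non-zero and its top bit is clear, then `x.wrapping_neg() == 2^64 - x` lands in the upper half of the range, so its top bit is set; either way `x | x.wrapping_neg()` has the MSB set for every non-zero `x` and is `0` for `x == 0`. The `black_box` hint discourages the optimizer from rewriting the extraction as a branch. A standalone check:

```rust
fn msb_test(x: u64) -> u64 {
    (x | x.wrapping_neg()) >> (u64::BITS - 1)
}

fn main() {
    assert_eq!(msb_test(0), 0);
    for x in [1u64, 2, 1 << 63, u64::MAX] {
        assert_eq!(msb_test(x), 1);
    }
}
```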

-/// Return a [`u32::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
-#[cfg(not(target_arch = "arm"))]
-fn masknz32(condition: Condition) -> u32 {
-    testnz32(condition.into()).wrapping_neg()
-}
-
-/// Return a [`u64::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
-#[cfg(not(target_arch = "arm"))]
-fn masknz64(condition: Condition) -> u64 {
-    testnz64(condition.into()).wrapping_neg()
-}
-
-/// Optimized mask generation for ARM32 targets.
-#[cfg(target_arch = "arm")]
-fn masknz32(condition: u8) -> u32 {
-    let mut out = condition as u32;
-    unsafe {
-        core::arch::asm!(
-            "rsbs {0}, {0}, #0", // Reverse subtract
-            "sbcs {0}, {0}, {0}", // Subtract with carry, setting flags
-            inout(reg) out,
-            options(nostack, nomem),
-        );
-    }
-    out
-}
-
-/// 64-bit wrapper for targets that implement 32-bit mask generation in assembly.
-#[cfg(target_arch = "arm")]
-fn masknz64(condition: u8) -> u64 {
-    let mask = masknz32(condition) as u64;
-    mask | mask << 32
-}
-
 #[cfg(test)]
 mod tests {
+    #[test]
+    fn masknz32() {
+        assert_eq!(super::masknz32(0), 0);
+        for i in 1..=u8::MAX {
+            assert_eq!(super::masknz32(i.into()), u32::MAX);
+        }
+    }
+
+    #[test]
+    fn masknz64() {
+        assert_eq!(super::masknz64(0), 0);
+        for i in 1..=u8::MAX {
+            assert_eq!(super::masknz64(i.into()), u64::MAX);
+        }
+    }
+
+    #[test]
+    fn masksel() {
+        assert_eq!(super::masksel(23u8, 42u8, 0u8), 23u8);
+        assert_eq!(super::masksel(23u8, 42u8, u8::MAX), 42u8);
+
+        assert_eq!(super::masksel(17u32, 101077u32, 0u32), 17u32);
+        assert_eq!(super::masksel(17u32, 101077u32, u32::MAX), 101077u32);
+    }
+
     #[test]
     fn testeq32() {
         assert_eq!(super::testeq32(0, 0), 1);
@@ -219,20 +257,4 @@ mod tests {
             assert_eq!(super::testnz64(i as u64), 1);
         }
     }

-    #[test]
-    fn masknz32() {
-        assert_eq!(super::masknz32(0), 0);
-        for i in 1..=u8::MAX {
-            assert_eq!(super::masknz32(i), u32::MAX);
-        }
-    }
-
-    #[test]
-    fn masknz64() {
-        assert_eq!(super::masknz64(0), 0);
-        for i in 1..=u8::MAX {
-            assert_eq!(super::masknz64(i), u64::MAX);
-        }
-    }
 }