diff --git a/src/uu/stty/src/stty.rs b/src/uu/stty/src/stty.rs index fdeee252df3..42432c22c2b 100644 --- a/src/uu/stty/src/stty.rs +++ b/src/uu/stty/src/stty.rs @@ -10,7 +10,7 @@ // spell-checker:ignore isig icanon iexten echoe crterase echok echonl noflsh xcase tostop echoprt prterase echoctl ctlecho echoke crtkill flusho extproc // spell-checker:ignore lnext rprnt susp swtch vdiscard veof veol verase vintr vkill vlnext vquit vreprint vstart vstop vsusp vswtc vwerase werase // spell-checker:ignore sigquit sigtstp -// spell-checker:ignore cbreak decctlq evenp litout oddp tcsadrain +// spell-checker:ignore cbreak decctlq evenp litout oddp tcsadrain exta extb mod flags; @@ -23,6 +23,7 @@ use nix::sys::termios::{ Termios, cfgetospeed, cfsetospeed, tcgetattr, tcsetattr, }; use nix::{ioctl_read_bad, ioctl_write_ptr_bad}; +use std::cmp::Ordering; use std::fs::File; use std::io::{self, Stdout, stdout}; use std::num::IntErrorKind; @@ -563,7 +564,69 @@ fn string_to_combo(arg: &str) -> Option<&str> { .map(|_| arg) } +/// Parse and round a baud rate value using GNU stty's custom rounding algorithm. +/// +/// Accepts decimal values with the following rounding rules: +/// - If first digit after decimal > 5: round up +/// - If first digit after decimal < 5: round down +/// - If first digit after decimal == 5: +/// - If followed by any non-zero digit: round up +/// - If followed only by zeros (or nothing): banker's rounding (round to nearest even) +/// +/// Examples: "9600.49" -> 9600, "9600.51" -> 9600, "9600.5" -> 9600 (even), "9601.5" -> 9602 (even) +/// TODO: there are two special cases "exta" → B19200 and "extb" → B38400 +fn parse_baud_with_rounding(normalized: &str) -> Option { + let (int_part, frac_part) = match normalized.split_once('.') { + Some((i, f)) => (i, Some(f)), + None => (normalized, None), + }; + + let mut value = int_part.parse::().ok()?; + + if let Some(frac) = frac_part { + let mut chars = frac.chars(); + let first_digit = chars.next()?.to_digit(10)?; + + // Validate all remaining chars are digits + let rest: Vec<_> = chars.collect(); + if !rest.iter().all(|c| c.is_ascii_digit()) { + return None; + } + + match first_digit.cmp(&5) { + Ordering::Greater => value += 1, + Ordering::Equal => { + // Check if any non-zero digit follows + if rest.iter().any(|&c| c != '0') { + value += 1; + } else { + // Banker's rounding: round to nearest even + value += value & 1; + } + } + Ordering::Less => {} // Round down, already validated + } + } + + Some(value) +} + fn string_to_baud(arg: &str) -> Option> { + // Reject invalid formats + if arg != arg.trim_end() + || arg.trim().starts_with('-') + || arg.trim().starts_with("++") + || arg.contains('E') + || arg.contains('e') + || arg.matches('.').count() > 1 + { + return None; + } + + let normalized = arg.trim().trim_start_matches('+'); + let normalized = normalized.strip_suffix('.').unwrap_or(normalized); + let value = parse_baud_with_rounding(normalized)?; + // BSDs use a u32 for the baud rate, so any decimal number applies. #[cfg(any( target_os = "freebsd", @@ -573,9 +636,7 @@ fn string_to_baud(arg: &str) -> Option> { target_os = "netbsd", target_os = "openbsd" ))] - if let Ok(n) = arg.parse::() { - return Some(AllFlags::Baud(n)); - } + return Some(AllFlags::Baud(value)); #[cfg(not(any( target_os = "freebsd", @@ -585,12 +646,14 @@ fn string_to_baud(arg: &str) -> Option> { target_os = "netbsd", target_os = "openbsd" )))] - for (text, baud_rate) in BAUD_RATES { - if *text == arg { - return Some(AllFlags::Baud(*baud_rate)); + { + for (text, baud_rate) in BAUD_RATES { + if text.parse::().ok() == Some(value) { + return Some(AllFlags::Baud(*baud_rate)); + } } + None } - None } /// return `Some(flag)` if the input is a valid flag, `None` if not diff --git a/tests/by-util/test_stty.rs b/tests/by-util/test_stty.rs index d6870d48fff..9626c140656 100644 --- a/tests/by-util/test_stty.rs +++ b/tests/by-util/test_stty.rs @@ -194,6 +194,24 @@ fn invalid_baud_setting() { .args(&["ospeed", "995"]) .fails() .stderr_contains("invalid ospeed '995'"); + + for speed in &[ + "9599..", "9600..", "9600.5.", "9600.50.", "9600.0.", "++9600", "0x2580", "96E2", "9600,0", + "9600.0 ", + ] { + new_ucmd!().args(&["ispeed", speed]).fails(); + } +} + +#[test] +#[cfg(unix)] +fn valid_baud_formats() { + let (path, _controller, _replica) = pty_path(); + for speed in &[" +9600", "9600.49", "9600.50", "9599.51", " 9600."] { + new_ucmd!() + .args(&["--file", &path, "ispeed", speed]) + .succeeds(); + } } #[test]