diff --git a/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.ShiftRot.cs b/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.ShiftRot.cs index 0ea6e11b122f5d..af8fd82879585d 100644 --- a/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.ShiftRot.cs +++ b/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.ShiftRot.cs @@ -185,64 +185,48 @@ public static void RightShiftSelf(Span bits, int shift, out nuint carry) int back = BitsPerLimb - shift; - if (Vector128.IsHardwareAccelerated) - { - carry = bits[0] << back; + carry = bits[0] << back; - ref nuint start = ref MemoryMarshal.GetReference(bits); - int offset = 0; + Span remaining = bits; - while (Vector512.IsHardwareAccelerated && bits.Length - offset >= Vector512.Count + 1) - { - Vector512 current = Vector512.LoadUnsafe(ref start, (nuint)offset) >> shift; - Vector512 carries = Vector512.LoadUnsafe(ref start, (nuint)(offset + 1)) << back; - - Vector512 newValue = current | carries; + while (Vector512.IsHardwareAccelerated && remaining.Length >= Vector512.Count + 1) + { + Vector512 current = Vector512.Create(remaining) >> shift; + Vector512 carries = Vector512.Create(remaining.Slice(1)) << back; - Vector512.StoreUnsafe(newValue, ref start, (nuint)offset); - offset += Vector512.Count; - } + Vector512 newValue = current | carries; - while (Vector256.IsHardwareAccelerated && bits.Length - offset >= Vector256.Count + 1) - { - Vector256 current = Vector256.LoadUnsafe(ref start, (nuint)offset) >> shift; - Vector256 carries = Vector256.LoadUnsafe(ref start, (nuint)(offset + 1)) << back; + newValue.CopyTo(remaining); + remaining = remaining.Slice(Vector512.Count); + } - Vector256 newValue = current | carries; + while (Vector256.IsHardwareAccelerated && remaining.Length >= Vector256.Count + 1) + { + Vector256 current = Vector256.Create(remaining) >> shift; + Vector256 carries = Vector256.Create(remaining.Slice(1)) << back; - Vector256.StoreUnsafe(newValue, ref start, (nuint)offset); - offset += Vector256.Count; - } + Vector256 newValue = current | carries; - while (Vector128.IsHardwareAccelerated && bits.Length - offset >= Vector128.Count + 1) - { - Vector128 current = Vector128.LoadUnsafe(ref start, (nuint)offset) >> shift; - Vector128 carries = Vector128.LoadUnsafe(ref start, (nuint)(offset + 1)) << back; + newValue.CopyTo(remaining); + remaining = remaining.Slice(Vector256.Count); + } - Vector128 newValue = current | carries; + while (Vector128.IsHardwareAccelerated && remaining.Length >= Vector128.Count + 1) + { + Vector128 current = Vector128.Create(remaining) >> shift; + Vector128 carries = Vector128.Create(remaining.Slice(1)) << back; - Vector128.StoreUnsafe(newValue, ref start, (nuint)offset); - offset += Vector128.Count; - } + Vector128 newValue = current | carries; - nuint carry2 = 0; - for (int i = bits.Length - 1; i >= offset; i--) - { - nuint value = carry2 | bits[i] >> shift; - carry2 = bits[i] << back; - bits[i] = value; - } + newValue.CopyTo(remaining); + remaining = remaining.Slice(Vector128.Count); } - else + + for (int i = 0; i < remaining.Length - 1; i++) { - carry = 0; - for (int i = bits.Length - 1; i >= 0; i--) - { - nuint value = carry | bits[i] >> shift; - carry = bits[i] << back; - bits[i] = value; - } + remaining[i] = (remaining[i] >> shift) | (remaining[i + 1] << back); } + remaining[remaining.Length - 1] >>= shift; } } }