From 93903a1baa960ec73e65b88bb8f7c1f489ab82b1 Mon Sep 17 00:00:00 2001 From: tompng Date: Sun, 17 Aug 2025 15:07:48 +0900 Subject: [PATCH 1/5] Implement faster multiplication using Number Theoretic Transform Performs ntt with three primes (29<<27|1, 26<<27|1, 24<<27|1) --- bigdecimal.gemspec | 1 + ext/bigdecimal/bigdecimal.c | 39 ++++++ ext/bigdecimal/ntt.h | 191 +++++++++++++++++++++++++++ test/bigdecimal/test_vp_operation.rb | 13 ++ 4 files changed, 244 insertions(+) create mode 100644 ext/bigdecimal/ntt.h diff --git a/bigdecimal.gemspec b/bigdecimal.gemspec index b6ef8fd9..2c1550cd 100644 --- a/bigdecimal.gemspec +++ b/bigdecimal.gemspec @@ -46,6 +46,7 @@ Gem::Specification.new do |s| ext/bigdecimal/feature.h ext/bigdecimal/missing.c ext/bigdecimal/missing.h + ext/bigdecimal/ntt.h ext/bigdecimal/missing/dtoa.c ext/bigdecimal/static_assert.h ] diff --git a/ext/bigdecimal/bigdecimal.c b/ext/bigdecimal/bigdecimal.c index d9247790..1928b74c 100644 --- a/ext/bigdecimal/bigdecimal.c +++ b/ext/bigdecimal/bigdecimal.c @@ -33,6 +33,12 @@ #define BIGDECIMAL_VERSION "3.3.1" +#if SIZEOF_DECDIG == 4 +#define USE_NTT_MULTIPLICATION 1 +#include "ntt.h" +#define NTT_MULTIPLICATION_THRESHOLD 100 +#endif + /* #define ENABLE_NUMERIC_STRING */ #define SIGNED_VALUE_MAX INTPTR_MAX @@ -3281,6 +3287,25 @@ BigDecimal_vpmult(VALUE self, VALUE v) { RB_GC_GUARD(b.bigdecimal); return c.bigdecimal; } + +#if SIZEOF_DECDIG == 4 +VALUE +BigDecimal_nttmult(VALUE self, VALUE v) { + BDVALUE a,b,c; + a = GetBDValueMust(self); + b = GetBDValueMust(v); + c = NewZeroWrap(1, VPMULT_RESULT_PREC(a.real, b.real) * BASE_FIG); + ntt_multiply(a.real->Prec, b.real->Prec, a.real->frac, b.real->frac, c.real->frac); + VpSetSign(c.real, a.real->sign * b.real->sign); + c.real->exponent = a.real->exponent + b.real->exponent; + c.real->Prec = a.real->Prec + b.real->Prec; + VpNmlz(c.real); + RB_GC_GUARD(a.bigdecimal); + RB_GC_GUARD(b.bigdecimal); + return c.bigdecimal; +} +#endif + #endif /* BIGDECIMAL_USE_VP_TEST_METHODS */ /* Document-class: BigDecimal @@ -3653,6 +3678,9 @@ Init_bigdecimal(void) #ifdef BIGDECIMAL_USE_VP_TEST_METHODS rb_define_method(rb_cBigDecimal, "vpdivd", BigDecimal_vpdivd, 2); rb_define_method(rb_cBigDecimal, "vpmult", BigDecimal_vpmult, 1); +#ifdef USE_NTT_MULTIPLICATION + rb_define_method(rb_cBigDecimal, "nttmult", BigDecimal_nttmult, 1); +#endif #endif /* BIGDECIMAL_USE_VP_TEST_METHODS */ #define ROUNDING_MODE(i, name, value) \ @@ -4935,6 +4963,15 @@ VpMult(Real *c, Real *a, Real *b) c->exponent = a->exponent; /* set exponent */ VpSetSign(c, VpGetSign(a) * VpGetSign(b)); /* set sign */ if (!AddExponent(c, b->exponent)) return 0; + +#ifdef USE_NTT_MULTIPLICATION + if (b->Prec >= NTT_MULTIPLICATION_THRESHOLD) { + ntt_multiply((uint32_t)a->Prec, (uint32_t)b->Prec, a->frac, b->frac, c->frac); + c->Prec = a->Prec + b->Prec; + goto Cleanup; + } +#endif + carry = 0; nc = ind_c = MxIndAB; memset(c->frac, 0, (nc + 1) * sizeof(DECDIG)); /* Initialize c */ @@ -4981,6 +5018,8 @@ VpMult(Real *c, Real *a, Real *b) } } } + +Cleanup: VpNmlz(c); Exit: diff --git a/ext/bigdecimal/ntt.h b/ext/bigdecimal/ntt.h new file mode 100644 index 00000000..941f23f7 --- /dev/null +++ b/ext/bigdecimal/ntt.h @@ -0,0 +1,191 @@ +// NTT (Number Theoretic Transform) implementation for BigDecimal multiplication + +#define NTT_PRIMITIVE_ROOT 17 +#define NTT_PRIME_BASE1 24 +#define NTT_PRIME_BASE2 26 +#define NTT_PRIME_BASE3 29 +#define NTT_PRIME_SHIFT 27 +#define NTT_PRIME1 (((uint32_t)NTT_PRIME_BASE1 << NTT_PRIME_SHIFT) | 1) +#define NTT_PRIME2 (((uint32_t)NTT_PRIME_BASE2 << NTT_PRIME_SHIFT) | 1) +#define NTT_PRIME3 (((uint32_t)NTT_PRIME_BASE3 << NTT_PRIME_SHIFT) | 1) +#define MAX_NTT32_BITS 27 +#define NTT_DECDIG_BASE 1000000000 + +// Calculates base**ex % mod +static uint32_t +mod_pow(uint32_t base, uint32_t ex, uint32_t mod) { + uint32_t res = 1; + uint32_t bit = 1; + while (true) { + if (ex & bit) { + ex ^= bit; + res = ((uint64_t)res * base) % mod; + } + if (!ex) break; + base = ((uint64_t)base * base) % mod; + bit <<= 1; + } + return res; +} + +// Recursively performs butterfly operations of NTT +static void +ntt_recursive(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int depth, uint32_t r, uint32_t prime) { + if (depth > 0) { + ntt_recursive(size_bits, input, tmp, output, depth - 1, ((uint64_t)r * r) % prime, prime); + } else { + tmp = input; + } + uint32_t size_half = (uint32_t)1 << (size_bits - 1); + uint32_t stride = (uint32_t)1 << (size_bits - depth - 1); + uint32_t n = size_half / stride; + uint32_t rn = 1, rm = prime - 1; + uint32_t idx = 0; + for (uint32_t i = 0; i < n; i++) { + uint32_t j = i * 2 * stride; + for (uint32_t k = 0; k < stride; k++, j++, idx++) { + uint32_t a = tmp[j], b = tmp[j + stride]; + output[idx] = (a + (uint64_t)rn * b) % prime; + output[idx + size_half] = (a + (uint64_t)rm * b) % prime; + } + rn = ((uint64_t)rn * r) % prime; + rm = ((uint64_t)rm * r) % prime; + } +} + +/* Perform NTT on input array. + * base, shift: Represent the prime number as (base << shift | 1) + * r_base: Primitive root of unity modulo prime + * size_bits: log2 of the size of the input array. Should be less or equal to shift + * input: input array of size (1 << size_bits) + */ +static void +ntt(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int r_base, int base, int shift, int dir) { + uint32_t size = (uint32_t)1 << size_bits; + uint32_t prime = ((uint32_t)base << shift) | 1; + + // rmax**(1 << shift) % prime == 1 + // r**size % prime == 1 + uint32_t rmax = mod_pow(r_base, base, prime); + uint32_t r = mod_pow(rmax, (uint32_t)1 << (shift - size_bits), prime); + + if (dir < 0) r = mod_pow(r, prime - 2, prime); + ntt_recursive(size_bits, input, output, tmp, size_bits - 1, r, prime); + if (dir < 0) { + uint32_t n_inv = mod_pow((uint32_t)size, prime - 2, prime); + for (uint32_t i = 0; i < size; i++) { + output[i] = ((uint64_t)output[i] * n_inv) % prime; + } + } +} + +/* Calculate c that satisfies: c % PRIME1 == mod1 && c % PRIME2 == mod2 && c % PRIME3 == mod3 + * c = (mod1 * 35002755423056150739595925972 + mod2 * 14584479687667766215746868453 + mod3 * 37919651490985126265126719818) % (PRIME1 * PRIME2 * PRIME3) + * Assume c <= 999999999**2*(1<<27) + */ +static inline void +mod_restore_prime_24_26_29_shift_27(uint32_t mod1, uint32_t mod2, uint32_t mod3, uint32_t *digits) { + // Use mixed radix notation to eliminate modulo by PRIME1 * PRIME2 * PRIME3 + // [DIG0, DIG1, DIG2] = DIG0 + DIG1 * PRIME1 + DIG2 * PRIME1 * PRIME2 + // DIG0: 0...PRIME1, DIG1: 0...PRIME2, DIG2: 0...PRIME3 + // 35002755423056150739595925972 = [1, 3489660916, 3113851359] + // 14584479687667766215746868453 = [0, 13, 1297437912] + // 37919651490985126265126719818 = [0, 0, 3373338954] + uint64_t c0 = mod1; + uint64_t c1 = (uint64_t)mod2 * 13 + (uint64_t)mod1 * 3489660916; + uint64_t c2 = (uint64_t)mod3 * 3373338954 % NTT_PRIME3 + (uint64_t)mod2 * 1297437912 % NTT_PRIME3 + (uint64_t)mod1 * 3113851359 % NTT_PRIME3; + c2 += c1 / NTT_PRIME2; + c1 %= NTT_PRIME2; + c2 %= NTT_PRIME3; + // Base conversion. c fits in 3 digits. + c1 += c2 % NTT_DECDIG_BASE * NTT_PRIME2; + c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1; + c1 /= NTT_DECDIG_BASE; + digits[0] = c0 % NTT_DECDIG_BASE; + c0 /= NTT_DECDIG_BASE; + c1 += c2 / NTT_DECDIG_BASE % NTT_DECDIG_BASE * NTT_PRIME2; + c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1; + c1 /= NTT_DECDIG_BASE; + digits[1] = c0 % NTT_DECDIG_BASE; + digits[2] = (uint32_t)(c0 / NTT_DECDIG_BASE + c1 % NTT_DECDIG_BASE * NTT_PRIME1); +} + +/* + * NTT multiplication + * Uses three NTTs with mod (24 << 27 | 1), (26 << 27 | 1), and (29 << 27 | 1) + */ +static void +ntt_multiply(size_t a_size, size_t b_size, uint32_t *a, uint32_t *b, uint32_t *c) { + if (a_size < b_size) { + ntt_multiply(b_size, a_size, b, a, c); + return; + } + + int b_bits = 0; + while (((uint32_t)1 << b_bits) < (uint32_t)b_size) b_bits++; + int ntt_size_bits = b_bits + 1; + if (ntt_size_bits > MAX_NTT32_BITS) { + rb_raise(rb_eArgError, "Multiply size too large"); + } + + // To calculate large_a * small_b faster, split into several batches. + uint32_t ntt_size = (uint32_t)1 << ntt_size_bits; + uint32_t batch_size = ntt_size - (uint32_t)b_size; + uint32_t batch_count = (uint32_t)((a_size + batch_size - 1) / batch_size); + + uint32_t *mem = ruby_xcalloc(sizeof(uint32_t), ntt_size * 9); + uint32_t *ntt1 = mem; + uint32_t *ntt2 = mem + ntt_size; + uint32_t *ntt3 = mem + ntt_size * 2; + uint32_t *tmp1 = mem + ntt_size * 3; + uint32_t *tmp2 = mem + ntt_size * 4; + uint32_t *tmp3 = mem + ntt_size * 5; + uint32_t *conv1 = mem + ntt_size * 6; + uint32_t *conv2 = mem + ntt_size * 7; + uint32_t *conv3 = mem + ntt_size * 8; + + // Calculate NTT for b in three primes. Result is reused for each batch of a. + memcpy(tmp1, b, b_size * sizeof(uint32_t)); + memset(tmp1 + b_size, 0, (ntt_size - b_size) * sizeof(uint32_t)); + ntt(ntt_size_bits, tmp1, ntt1, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1); + ntt(ntt_size_bits, tmp1, ntt2, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1); + ntt(ntt_size_bits, tmp1, ntt3, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1); + + memset(c, 0, (a_size + b_size) * sizeof(uint32_t)); + for (uint32_t idx = 0; idx < batch_count; idx++) { + uint32_t len = idx == batch_count - 1 ? (uint32_t)a_size - idx * batch_size : batch_size; + memcpy(tmp1, a + idx * batch_size, len * sizeof(uint32_t)); + memset(tmp1 + len, 0, (ntt_size - len) * sizeof(uint32_t)); + // Calculate convolution for this batch in three primes + ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1); + for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt1[i]) % NTT_PRIME1; + ntt(ntt_size_bits, tmp2, conv1, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, -1); + ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1); + for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt2[i]) % NTT_PRIME2; + ntt(ntt_size_bits, tmp2, conv2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, -1); + ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1); + for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt3[i]) % NTT_PRIME3; + ntt(ntt_size_bits, tmp2, conv3, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, -1); + + // Restore the original convolution value from three convolutions calculated in three primes. + // Each convolution value is maximum 999999999**2*(1<<27)/2 + for (uint32_t i = 0; i < ntt_size; i++) { + uint32_t dig[3]; + mod_restore_prime_24_26_29_shift_27(conv1[i], conv2[i], conv3[i], dig); + // Maximum values of dig[0], dig[1], and dig[2] are 999999999, 999999999 and 67108863 respectively + // Maximum overlapped sum (considering overlaps between 2 batches) is less than 4134217722 + // so this sum doesn't overflow uint32_t. + for (int j = 0; j < 3; j++) { + // Index check: if dig[j] is non-zero, assign index is within valid range. + if (dig[j]) c[idx * batch_size + i + 1 - j] += dig[j]; + } + } + } + uint32_t carry = 0; + for (int32_t i = (uint32_t)(a_size + b_size - 1); i >= 0; i--) { + uint32_t v = c[i] + carry; + c[i] = v % NTT_DECDIG_BASE; + carry = v / NTT_DECDIG_BASE; + } + ruby_xfree(mem); +} diff --git a/test/bigdecimal/test_vp_operation.rb b/test/bigdecimal/test_vp_operation.rb index 075df0b6..fd6774a4 100644 --- a/test/bigdecimal/test_vp_operation.rb +++ b/test/bigdecimal/test_vp_operation.rb @@ -13,6 +13,10 @@ def setup end end + def ntt_mult_available? + BASE_FIG == 9 + end + def test_vpmult assert_equal(BigDecimal('121932631112635269'), BigDecimal('123456789').vpmult(BigDecimal('987654321'))) assert_equal(BigDecimal('12193263.1112635269'), BigDecimal('123.456789').vpmult(BigDecimal('98765.4321'))) @@ -21,6 +25,15 @@ def test_vpmult assert_equal(BigDecimal("#{x * y}e-300"), BigDecimal("#{x}e-100").vpmult(BigDecimal("#{y}e-200"))) end + def test_nttmult + omit 'NTT multiplication is only available for 32-bit DECDIG' unless ntt_mult_available? + [*1..32].repeated_permutation(2) do |a, b| + x = BigDecimal(10 ** (BASE_FIG * a) / 7) + y = BigDecimal(10 ** (BASE_FIG * b) / 13) + assert_equal(x.to_i * y.to_i, x.nttmult(y)) + end + end + def test_vpdivd # a[0] > b[0] # XXXX_YYYY_ZZZZ / 1111 #=> 000X_000Y_000Z From 7d598943aad7cee7c0512fd93915696bb36692a4 Mon Sep 17 00:00:00 2001 From: tompng Date: Tue, 19 Aug 2025 01:19:57 +0900 Subject: [PATCH 2/5] Implement Newton-Raphson division Improve performance of huge divisions --- bigdecimal.gemspec | 1 + ext/bigdecimal/bigdecimal.c | 47 +++++-- ext/bigdecimal/bigdecimal.h | 26 ++++ ext/bigdecimal/div.h | 192 +++++++++++++++++++++++++++ test/bigdecimal/test_vp_operation.rb | 107 +++++++++++---- 5 files changed, 334 insertions(+), 39 deletions(-) create mode 100644 ext/bigdecimal/div.h diff --git a/bigdecimal.gemspec b/bigdecimal.gemspec index 2c1550cd..6b20ac08 100644 --- a/bigdecimal.gemspec +++ b/bigdecimal.gemspec @@ -43,6 +43,7 @@ Gem::Specification.new do |s| ext/bigdecimal/bigdecimal.c ext/bigdecimal/bigdecimal.h ext/bigdecimal/bits.h + ext/bigdecimal/div.h ext/bigdecimal/feature.h ext/bigdecimal/missing.c ext/bigdecimal/missing.h diff --git a/ext/bigdecimal/bigdecimal.c b/ext/bigdecimal/bigdecimal.c index 1928b74c..5c249270 100644 --- a/ext/bigdecimal/bigdecimal.c +++ b/ext/bigdecimal/bigdecimal.c @@ -29,6 +29,7 @@ #endif #include "bits.h" +#include "div.h" #include "static_assert.h" #define BIGDECIMAL_VERSION "3.3.1" @@ -37,6 +38,7 @@ #define USE_NTT_MULTIPLICATION 1 #include "ntt.h" #define NTT_MULTIPLICATION_THRESHOLD 100 +#define NEWTON_RAPHSON_DIVISION_THRESHOLD 200 #endif /* #define ENABLE_NUMERIC_STRING */ @@ -81,11 +83,6 @@ static struct { uint8_t mode; } rbd_rounding_modes[RBD_NUM_ROUNDING_MODES]; -typedef struct { - VALUE bigdecimal; - Real *real; -} BDVALUE; - typedef struct { VALUE bigdecimal_or_nil; Real *real_or_null; @@ -213,7 +210,6 @@ rbd_allocate_struct_zero(int sign, size_t const digits) static unsigned short VpGetException(void); static void VpSetException(unsigned short f); static void VpCheckException(Real *p, bool always); -static int AddExponent(Real *a, SIGNED_VALUE n); static VALUE CheckGetValue(BDVALUE v); static void VpInternalRound(Real *c, size_t ixDigit, DECDIG vPrev, DECDIG v); static int VpLimitRound(Real *c, size_t ixDigit); @@ -1118,9 +1114,6 @@ BigDecimal_check_num(Real *p) VpCheckException(p, true); } -static VALUE BigDecimal_fix(VALUE self); -static VALUE BigDecimal_split(VALUE self); - /* Returns the value as an Integer. * * If the BigDecimal is infinity or NaN, raises FloatDomainError. @@ -3263,19 +3256,39 @@ BigDecimal_literal(const char *str) #ifdef BIGDECIMAL_USE_VP_TEST_METHODS VALUE -BigDecimal_vpdivd(VALUE self, VALUE r, VALUE cprec) { - BDVALUE a,b,c,d; +BigDecimal_vpdivd_generic(VALUE self, VALUE r, VALUE cprec, void (*vpdivd_func)(Real*, Real*, Real*, Real*)) { + BDVALUE a, b, c, d; size_t cn = NUM2INT(cprec); a = GetBDValueMust(self); b = GetBDValueMust(r); c = NewZeroWrap(1, cn * BASE_FIG); d = NewZeroWrap(1, VPDIVD_REM_PREC(a.real, b.real, c.real) * BASE_FIG); - VpDivd(c.real, d.real, a.real, b.real); + vpdivd_func(c.real, d.real, a.real, b.real); RB_GC_GUARD(a.bigdecimal); RB_GC_GUARD(b.bigdecimal); return rb_assoc_new(c.bigdecimal, d.bigdecimal); } +void +VpDivdNormal(Real *c, Real *r, Real *a, Real *b) { + VpDivd(c, r, a, b); +} + +VALUE +BigDecimal_vpdivd(VALUE self, VALUE r, VALUE cprec) { + return BigDecimal_vpdivd_generic(self, r, cprec, VpDivdNormal); +} + +VALUE +BigDecimal_vpdivd_newton(VALUE self, VALUE r, VALUE cprec) { + return BigDecimal_vpdivd_generic(self, r, cprec, VpDivdNewton); +} + +VALUE +BigDecimal_newton_raphson_inverse(VALUE self, VALUE prec) { + return newton_raphson_inverse(self, NUM2SIZET(prec)); +} + VALUE BigDecimal_vpmult(VALUE self, VALUE v) { BDVALUE a,b,c; @@ -3677,6 +3690,8 @@ Init_bigdecimal(void) #ifdef BIGDECIMAL_USE_VP_TEST_METHODS rb_define_method(rb_cBigDecimal, "vpdivd", BigDecimal_vpdivd, 2); + rb_define_method(rb_cBigDecimal, "vpdivd_newton", BigDecimal_vpdivd_newton, 2); + rb_define_method(rb_cBigDecimal, "newton_raphson_inverse", BigDecimal_newton_raphson_inverse, 1); rb_define_method(rb_cBigDecimal, "vpmult", BigDecimal_vpmult, 1); #ifdef USE_NTT_MULTIPLICATION rb_define_method(rb_cBigDecimal, "nttmult", BigDecimal_nttmult, 1); @@ -5067,6 +5082,14 @@ VpDivd(Real *c, Real *r, Real *a, Real *b) if (word_a > word_r || word_b + word_c - 2 >= word_r) goto space_error; +#ifdef USE_NTT_MULTIPLICATION + // Newton-Raphson division requires multiplication to be faster than O(n^2) + if (word_c >= NEWTON_RAPHSON_DIVISION_THRESHOLD && word_b >= NEWTON_RAPHSON_DIVISION_THRESHOLD) { + VpDivdNewton(c, r, a, b); + goto Exit; + } +#endif + for (i = 0; i < word_a; ++i) r->frac[i] = a->frac[i]; for (i = word_a; i < word_r; ++i) r->frac[i] = 0; for (i = 0; i < word_c; ++i) c->frac[i] = 0; diff --git a/ext/bigdecimal/bigdecimal.h b/ext/bigdecimal/bigdecimal.h index 82c88a2a..71ddb21f 100644 --- a/ext/bigdecimal/bigdecimal.h +++ b/ext/bigdecimal/bigdecimal.h @@ -188,6 +188,11 @@ typedef struct { DECDIG frac[FLEXIBLE_ARRAY_SIZE]; /* Array of fraction part. */ } Real; +typedef struct { + VALUE bigdecimal; + Real *real; +} BDVALUE; + /* * ------------------ * EXPORTables. @@ -232,10 +237,31 @@ VP_EXPORT int VpActiveRound(Real *y, Real *x, unsigned short f, ssize_t il); VP_EXPORT int VpMidRound(Real *y, unsigned short f, ssize_t nf); VP_EXPORT int VpLeftRound(Real *y, unsigned short f, ssize_t nf); VP_EXPORT void VpFrac(Real *y, Real *x); +VP_EXPORT int AddExponent(Real *a, SIGNED_VALUE n); /* VP constants */ VP_EXPORT Real *VpOne(void); +/* + * **** BigDecimal part **** + */ +VP_EXPORT VALUE BigDecimal_lt(VALUE self, VALUE r); +VP_EXPORT VALUE BigDecimal_ge(VALUE self, VALUE r); +VP_EXPORT VALUE BigDecimal_exponent(VALUE self); +VP_EXPORT VALUE BigDecimal_fix(VALUE self); +VP_EXPORT VALUE BigDecimal_frac(VALUE self); +VP_EXPORT VALUE BigDecimal_add(VALUE self, VALUE b); +VP_EXPORT VALUE BigDecimal_sub(VALUE self, VALUE b); +VP_EXPORT VALUE BigDecimal_mult(VALUE self, VALUE b); +VP_EXPORT VALUE BigDecimal_add2(VALUE self, VALUE b, VALUE n); +VP_EXPORT VALUE BigDecimal_sub2(VALUE self, VALUE b, VALUE n); +VP_EXPORT VALUE BigDecimal_mult2(VALUE self, VALUE b, VALUE n); +VP_EXPORT VALUE BigDecimal_split(VALUE self); +VP_EXPORT VALUE BigDecimal_decimal_shift(VALUE self, VALUE v); +VP_EXPORT inline BDVALUE GetBDValueMust(VALUE v); +VP_EXPORT inline BDVALUE rbd_allocate_struct_zero_wrap(int sign, size_t const digits); +#define NewZeroWrap rbd_allocate_struct_zero_wrap + /* * ------------------ * MACRO definitions. diff --git a/ext/bigdecimal/div.h b/ext/bigdecimal/div.h new file mode 100644 index 00000000..e6dd89c9 --- /dev/null +++ b/ext/bigdecimal/div.h @@ -0,0 +1,192 @@ +// Calculate the inverse of x using the Newton-Raphson method. +static VALUE +newton_raphson_inverse(VALUE x, size_t prec) { + BDVALUE bdone = NewZeroWrap(1, 1); + VpSetOne(bdone.real); + VALUE one = bdone.bigdecimal; + + // Initial approximation in 2 digits + BDVALUE bdx = GetBDValueMust(x); + BDVALUE inv0 = NewZeroWrap(1, 2 * BIGDECIMAL_COMPONENT_FIGURES); + VpSetOne(inv0.real); + DECDIG_DBL numerator = (DECDIG_DBL)BIGDECIMAL_BASE * 100; + DECDIG_DBL denominator = (DECDIG_DBL)bdx.real->frac[0] * 100 + (DECDIG_DBL)(bdx.real->Prec >= 2 ? bdx.real->frac[1] : 0) * 100 / BIGDECIMAL_BASE; + inv0.real->frac[0] = (DECDIG)(numerator / denominator); + inv0.real->frac[1] = (DECDIG)((numerator % denominator) * (BIGDECIMAL_BASE / 100) / denominator * 100); + inv0.real->Prec = 2; + inv0.real->exponent = 1 - bdx.real->exponent; + VpNmlz(inv0.real); + RB_GC_GUARD(bdx.bigdecimal); + VALUE inv = inv0.bigdecimal; + + int bl = 1; + while (((size_t)1 << bl) < prec) bl++; + + for (int i = bl; i >= 0; i--) { + size_t n = (prec >> i) + 2; + if (n > prec) n = prec; + // Newton-Raphson iteration: inv_next = inv + inv * (1 - x * inv) + VALUE one_minus_x_inv = BigDecimal_sub2( + one, + BigDecimal_mult(BigDecimal_mult2(x, one, SIZET2NUM(n + 1)), inv), + SIZET2NUM(SIZET2NUM(n / 2)) + ); + inv = BigDecimal_add2( + inv, + BigDecimal_mult(inv, one_minus_x_inv), + SIZET2NUM(n) + ); + } + return inv; +} + +// Calculates divmod by multiplying approximate reciprocal of y +static void +divmod_by_inv_mul(VALUE x, VALUE y, VALUE inv, VALUE *res_div, VALUE *res_mod) { + VALUE div = BigDecimal_fix(BigDecimal_mult(x, inv)); + VALUE mod = BigDecimal_sub(x, BigDecimal_mult(div, y)); + while (RTEST(BigDecimal_lt(mod, INT2FIX(0)))) { + mod = BigDecimal_add(mod, y); + div = BigDecimal_sub(div, INT2FIX(1)); + } + while (RTEST(BigDecimal_ge(mod, y))) { + mod = BigDecimal_sub(mod, y); + div = BigDecimal_add(div, INT2FIX(1)); + } + *res_div = div; + *res_mod = mod; +} + +static void +slice_copy(DECDIG *dest, Real *src, size_t rshift, size_t length) { + ssize_t start = src->exponent - rshift - length; + if (start >= (ssize_t)src->Prec) return; + if (start < 0) { + dest -= start; + length += start; + start = 0; + } + size_t max_length = src->Prec - start; + memcpy(dest, src->frac + start, Min(length, max_length) * sizeof(DECDIG)); +} + +/* Calculates divmod using Newton-Raphson method. + * x and y must be a BigDecimal representing an integer value. + * + * To calculate with low cost, we need to split x into blocks and perform divmod for each block. + * x_digits = remaining_digits(<= y_digits) + block_digits * num_blocks + * + * Example: + * xxx_xxxxx_xxxxx_xxxxx(18 digits) / yyyyy(5 digits) + * remaining_digits = 3, block_digits = 5, num_blocks = 3 + * repeating xxxxx_xxxxxx.divmod(yyyyy) calculation 3 times. + * + * In each divmod step, dividend is at most (y_digits + block_digits) digits and divisor is y_digits digits. + * Reciprocal of y needs block_digits + 1 precision. + */ +static void +divmod_newton(VALUE x, VALUE y, VALUE *div_out, VALUE *mod_out) { + size_t x_digits = NUM2SIZET(BigDecimal_exponent(x)); + size_t y_digits = NUM2SIZET(BigDecimal_exponent(y)); + if (x_digits <= y_digits) x_digits = y_digits + 1; + + size_t n = x_digits / y_digits; + size_t block_figs = (x_digits - y_digits) / n / BIGDECIMAL_COMPONENT_FIGURES + 1; + size_t block_digits = block_figs * BIGDECIMAL_COMPONENT_FIGURES; + size_t num_blocks = (x_digits - y_digits + block_digits - 1) / block_digits; + size_t y_figs = (y_digits - 1) / BIGDECIMAL_COMPONENT_FIGURES + 1; + VALUE yinv = newton_raphson_inverse(y, block_digits + 1); + + BDVALUE divident = NewZeroWrap(1, BIGDECIMAL_COMPONENT_FIGURES * (y_figs + block_figs)); + BDVALUE div_result = NewZeroWrap(1, BIGDECIMAL_COMPONENT_FIGURES * (num_blocks * block_figs + 1)); + BDVALUE bdx = GetBDValueMust(x); + + VALUE mod = BigDecimal_fix(BigDecimal_decimal_shift(x, SSIZET2NUM(-num_blocks * block_digits))); + for (ssize_t i = num_blocks - 1; i >= 0; i--) { + memset(divident.real->frac, 0, (y_figs + block_figs) * sizeof(DECDIG)); + + BDVALUE bdmod = GetBDValueMust(mod); + slice_copy(divident.real->frac, bdmod.real, 0, y_figs); + slice_copy(divident.real->frac + y_figs, bdx.real, i * block_figs, block_figs); + RB_GC_GUARD(bdmod.bigdecimal); + + VpSetSign(divident.real, 1); + divident.real->exponent = y_figs + block_figs; + divident.real->Prec = y_figs + block_figs; + VpNmlz(divident.real); + + VALUE div; + divmod_by_inv_mul(divident.bigdecimal, y, yinv, &div, &mod); + BDVALUE bddiv = GetBDValueMust(div); + slice_copy(div_result.real->frac + (num_blocks - i - 1) * block_figs, bddiv.real, 0, block_figs + 1); + RB_GC_GUARD(bddiv.bigdecimal); + } + VpSetSign(div_result.real, 1); + div_result.real->exponent = num_blocks * block_figs + 1; + div_result.real->Prec = num_blocks * block_figs + 1; + VpNmlz(div_result.real); + RB_GC_GUARD(bdx.bigdecimal); + RB_GC_GUARD(divident.bigdecimal); + RB_GC_GUARD(div_result.bigdecimal); + *div_out = div_result.bigdecimal; + *mod_out = mod; +} + +static VALUE +VpDivdNewtonInner(VALUE args_ptr) +{ + Real **args = (Real**)args_ptr; + Real *c = args[0], *r = args[1], *a = args[2], *b = args[3]; + BDVALUE a2, b2, c2, r2; + VALUE div, mod, a2_frac = Qnil; + size_t div_prec = c->MaxPrec - 1; + size_t base_prec = b->Prec; + + a2 = NewZeroWrap(1, a->Prec * BIGDECIMAL_COMPONENT_FIGURES); + b2 = NewZeroWrap(1, b->Prec * BIGDECIMAL_COMPONENT_FIGURES); + VpAsgn(a2.real, a, 1); + VpAsgn(b2.real, b, 1); + VpSetSign(a2.real, 1); + VpSetSign(b2.real, 1); + a2.real->exponent = base_prec + div_prec; + b2.real->exponent = base_prec; + + if ((ssize_t)a2.real->Prec > a2.real->exponent) { + a2_frac = BigDecimal_frac(a2.bigdecimal); + VpMidRound(a2.real, VP_ROUND_DOWN, 0); + } + divmod_newton(a2.bigdecimal, b2.bigdecimal, &div, &mod); + if (a2_frac != Qnil) mod = BigDecimal_add(mod, a2_frac); + + c2 = GetBDValueMust(div); + r2 = GetBDValueMust(mod); + VpAsgn(c, c2.real, VpGetSign(a) * VpGetSign(b)); + VpAsgn(r, r2.real, VpGetSign(a)); + AddExponent(c, a->exponent); + AddExponent(c, -b->exponent); + AddExponent(c, -div_prec); + AddExponent(r, a->exponent); + AddExponent(r, -base_prec - div_prec); + RB_GC_GUARD(a2.bigdecimal); + RB_GC_GUARD(a2.bigdecimal); + RB_GC_GUARD(c2.bigdecimal); + RB_GC_GUARD(r2.bigdecimal); + return Qnil; +} + +static VALUE +ensure_restore_prec_limit(VALUE limit) +{ + VpSetPrecLimit(NUM2SIZET(limit)); + return Qnil; +} + +static void +VpDivdNewton(Real *c, Real *r, Real *a, Real *b) +{ + Real *args[4] = {c, r, a, b}; + size_t pl = VpGetPrecLimit(); + VpSetPrecLimit(0); + // Ensure restoring prec limit because some methods used in VpDivdNewtonInner may raise an exception + rb_ensure(VpDivdNewtonInner, (VALUE)args, ensure_restore_prec_limit, SIZET2NUM(pl)); +} diff --git a/test/bigdecimal/test_vp_operation.rb b/test/bigdecimal/test_vp_operation.rb index fd6774a4..389aba77 100644 --- a/test/bigdecimal/test_vp_operation.rb +++ b/test/bigdecimal/test_vp_operation.rb @@ -34,6 +34,59 @@ def test_nttmult end end + def test_newton_inverse + xs = [BigDecimal(3), BigDecimal('123e50'), BigDecimal('13' * 44), BigDecimal('17' * 45), BigDecimal('19' * 46)] + %i[up half_up down].each do |rounding_mode| + BigDecimal.save_rounding_mode do + BigDecimal.mode(BigDecimal::ROUND_MODE, rounding_mode) + [*1..32, 50, 100, 200, 300].each do |prec| + xs.each do |x| + inv = x.newton_raphson_inverse(prec) + assert_in_delta(1, x * inv, BigDecimal("1e#{1 - prec}")) + + high_precision_inv = inv * (2 - x * inv) + expected_inv = high_precision_inv.mult(1, prec) + last_digit = BigDecimal("1e#{expected_inv.exponent - prec}") + assert_include([expected_inv - last_digit, expected_inv, expected_inv + last_digit], inv) + end + end + end + end + end + + def test_not_affected_by_limit + x_int = 123**135 + y_int = 135**123 + xy_int = x_int * y_int + mod_int = 111**111 + x = BigDecimal(x_int) + y = BigDecimal(y_int) + xy = BigDecimal(xy_int) + mod = BigDecimal(mod_int) + z = BigDecimal(xy_int + mod_int) + BigDecimal.save_limit do + BigDecimal.limit 3 + assert_equal(xy, x.vpmult(y)) + assert_equal(3, BigDecimal.limit) + if ntt_mult_available? + assert_equal(xy, x.nttmult(y)) + assert_equal(3, BigDecimal.limit) + end + + prec = (z.exponent - 1) / BASE_FIG - (y.exponent - 1) / BASE_FIG + 1 + assert_equal([x, mod], z.vpdivd(y, prec)) + assert_equal(3, BigDecimal.limit) + assert_equal([x, mod], z.vpdivd_newton(y, prec)) + assert_equal(3, BigDecimal.limit) + end + end + + def assert_vpdivd_equal(expected_divmod, x_y_n) + x, *args = x_y_n + assert_equal(expected_divmod, x.vpdivd(*args)) + assert_equal(expected_divmod, x.vpdivd_newton(*args)) + end + def test_vpdivd # a[0] > b[0] # XXXX_YYYY_ZZZZ / 1111 #=> 000X_000Y_000Z @@ -44,11 +97,11 @@ def test_vpdivd d3 = BigDecimal("4e#{BASE_FIG * 2}") + d2 d4 = BigDecimal("5e#{BASE_FIG}") + d3 d5 = BigDecimal(6) + d4 - assert_equal([d1, x1 - d1 * y], x1.vpdivd(y, 1)) - assert_equal([d2, x1 - d2 * y], x1.vpdivd(y, 2)) - assert_equal([d3, x1 - d3 * y], x1.vpdivd(y, 3)) - assert_equal([d4, x1 - d4 * y], x1.vpdivd(y, 4)) - assert_equal([d5, x1 - d5 * y], x1.vpdivd(y, 5)) + assert_vpdivd_equal([d1, x1 - d1 * y], [x1, y, 1]) + assert_vpdivd_equal([d2, x1 - d2 * y], [x1, y, 2]) + assert_vpdivd_equal([d3, x1 - d3 * y], [x1, y, 3]) + assert_vpdivd_equal([d4, x1 - d4 * y], [x1, y, 4]) + assert_vpdivd_equal([d5, x1 - d5 * y], [x1, y, 5]) # a[0] < b[0] # 00XX_XXYY_YYZZ_ZZ00 / 1111 #=> 0000_0X00_0Y00_0Z00 @@ -59,28 +112,28 @@ def test_vpdivd d3 = BigDecimal("4e#{2 * BASE_FIG + shift}") + d2 d4 = BigDecimal("5e#{BASE_FIG + shift}") + d3 d5 = BigDecimal("6e#{shift}") + d4 - assert_equal([0, x2], x2.vpdivd(y, 1)) - assert_equal([d1, x2 - d1 * y], x2.vpdivd(y, 2)) - assert_equal([d2, x2 - d2 * y], x2.vpdivd(y, 3)) - assert_equal([d3, x2 - d3 * y], x2.vpdivd(y, 4)) - assert_equal([d4, x2 - d4 * y], x2.vpdivd(y, 5)) - assert_equal([d5, x2 - d5 * y], x2.vpdivd(y, 6)) + assert_vpdivd_equal([0, x2], [x2, y, 1]) + assert_vpdivd_equal([d1, x2 - d1 * y], [x2, y, 2]) + assert_vpdivd_equal([d2, x2 - d2 * y], [x2, y, 3]) + assert_vpdivd_equal([d3, x2 - d3 * y], [x2, y, 4]) + assert_vpdivd_equal([d4, x2 - d4 * y], [x2, y, 5]) + assert_vpdivd_equal([d5, x2 - d5 * y], [x2, y, 6]) end def test_vpdivd_large_quotient_prec # 0001 / 0003 = 0000_3333_3333 - assert_equal([BigDecimal('0.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], BigDecimal(1).vpdivd(BigDecimal(3), 10)) + assert_vpdivd_equal([BigDecimal('0.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], [BigDecimal(1), BigDecimal(3), 10]) # 1000 / 0003 = 0333_3333_3333 - assert_equal([BigDecimal('3' * (BASE_FIG - 1) + '.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], BigDecimal(BASE / 10).vpdivd(BigDecimal(3), 10)) + assert_vpdivd_equal([BigDecimal('3' * (BASE_FIG - 1) + '.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], [BigDecimal(BASE / 10), BigDecimal(3), 10]) end def test_vpdivd_with_one x = BigDecimal('1234.2468000001234') - assert_equal([BigDecimal('1234'), BigDecimal('0.2468000001234')], x.vpdivd(BigDecimal(1), 1)) - assert_equal([BigDecimal('+1234.2468'), BigDecimal('+0.1234e-9')], (+x).vpdivd(BigDecimal(+1), 2)) - assert_equal([BigDecimal('-1234.2468'), BigDecimal('+0.1234e-9')], (+x).vpdivd(BigDecimal(-1), 2)) - assert_equal([BigDecimal('-1234.2468'), BigDecimal('-0.1234e-9')], (-x).vpdivd(BigDecimal(+1), 2)) - assert_equal([BigDecimal('+1234.2468'), BigDecimal('-0.1234e-9')], (-x).vpdivd(BigDecimal(-1), 2)) + assert_vpdivd_equal([BigDecimal('1234'), BigDecimal('0.2468000001234')], [x, BigDecimal(1), 1]) + assert_vpdivd_equal([BigDecimal('+1234.2468'), BigDecimal('+0.1234e-9')], [+x, BigDecimal(+1), 2]) + assert_vpdivd_equal([BigDecimal('-1234.2468'), BigDecimal('+0.1234e-9')], [+x, BigDecimal(-1), 2]) + assert_vpdivd_equal([BigDecimal('-1234.2468'), BigDecimal('-0.1234e-9')], [-x, BigDecimal(+1), 2]) + assert_vpdivd_equal([BigDecimal('+1234.2468'), BigDecimal('-0.1234e-9')], [-x, BigDecimal(-1), 2]) end def test_vpdivd_precisions @@ -92,7 +145,7 @@ def test_vpdivd_precisions yn = (y.digits.size + BASE_FIG - 1) / BASE_FIG base = BASE ** (n - xn + yn - 1) div = BigDecimal((x * base / y).to_i) / base - assert_equal([div, x - y * div], BigDecimal(x).vpdivd(y, n)) + assert_vpdivd_equal([div, x - y * div], [BigDecimal(x), BigDecimal(y), n]) end end end @@ -105,7 +158,7 @@ def test_vpdivd_borrow x = y * (3 * BASE**4 + a * BASE**3 + b * BASE**2 + c * BASE + d) / BASE div = BigDecimal(x * BASE / y) / BASE mod = BigDecimal(x) - div * y - assert_equal([div, mod], BigDecimal(x).vpdivd(BigDecimal(y), 5)) + assert_vpdivd_equal([div, mod], [BigDecimal(x), BigDecimal(y), 5]) end end end @@ -117,22 +170,22 @@ def test_vpdivd_large_prec_divisor divy1_1 = BigDecimal(2) divy2_1 = BigDecimal(1) divy2_2 = BigDecimal('1.' + '9' * BASE_FIG) - assert_equal([divy1_1, x - y1 * divy1_1], x.vpdivd(y1, 1)) - assert_equal([divy2_1, x - y2 * divy2_1], x.vpdivd(y2, 1)) - assert_equal([divy2_2, x - y2 * divy2_2], x.vpdivd(y2, 2)) + assert_vpdivd_equal([divy1_1, x - y1 * divy1_1], [x, y1, 1]) + assert_vpdivd_equal([divy2_1, x - y2 * divy2_1], [x, y2, 1]) + assert_vpdivd_equal([divy2_2, x - y2 * divy2_2], [x, y2, 2]) end def test_vpdivd_intermediate_zero if BASE_FIG == 9 x = BigDecimal('123456789.246913578000000000123456789') y = BigDecimal('123456789') - assert_equal([BigDecimal('1.000000002000000000000000001'), BigDecimal(0)], x.vpdivd(y, 4)) - assert_equal([BigDecimal('1.000000000049999999'), BigDecimal('1e-18')], BigDecimal("2.000000000099999999").vpdivd(2, 3)) + assert_vpdivd_equal([BigDecimal('1.000000002000000000000000001'), BigDecimal(0)], [x, y, 4]) + assert_vpdivd_equal([BigDecimal('1.000000000049999999'), BigDecimal('1e-18')], [BigDecimal("2.000000000099999999"), 2, 3]) else x = BigDecimal('1234.246800001234') y = BigDecimal('1234') - assert_equal([BigDecimal('1.000200000001'), BigDecimal(0)], x.vpdivd(y, 4)) - assert_equal([BigDecimal('1.00000499'), BigDecimal('1e-8')], BigDecimal("2.00000999").vpdivd(2, 3)) + assert_vpdivd_equal([BigDecimal('1.000200000001'), BigDecimal(0)], [x, y, 4]) + assert_vpdivd_equal([BigDecimal('1.00000499'), BigDecimal('1e-8')], [BigDecimal("2.00000999"), 2, 3]) end end end From 2d454f6e0860b6d02b72cdd03fef93ab1a3372da Mon Sep 17 00:00:00 2001 From: tompng Date: Tue, 16 Sep 2025 01:42:07 +0900 Subject: [PATCH 3/5] Improve taylor series calculation of exp and sin by binary splitting method exp and sin becomes orders of magnitude faster. To make log and atan also fast, log and atan now depends on exp and sin. log(x): solve exp(y)-x=0 by Newton's method atan(x): solve tan(y)-x=0 by Newton's method --- lib/bigdecimal.rb | 133 ++++++++++++++++++-------------- lib/bigdecimal/math.rb | 90 +++++++++++---------- test/bigdecimal/test_bigmath.rb | 34 ++++---- 3 files changed, 145 insertions(+), 112 deletions(-) diff --git a/lib/bigdecimal.rb b/lib/bigdecimal.rb index 12250ce9..998087d8 100644 --- a/lib/bigdecimal.rb +++ b/lib/bigdecimal.rb @@ -60,6 +60,46 @@ def self.nan_computation_result # :nodoc: end BigDecimal::NAN end + + # Iteration for Newton's method with increasing precision + def self.newton_loop(prec, initial_precision: BigDecimal.double_fig / 2, safe_margin: 2) # :nodoc: + precs = [] + while prec > initial_precision + precs << prec + prec = (precs.last + 1) / 2 + safe_margin + end + precs.reverse_each do |p| + yield p + end + end + + # Calculates Math.log(x.to_f) considering large or small exponent + def self.float_log(x) # :nodoc: + Math.log(x._decimal_shift(-x.exponent).to_f) + x.exponent * Math.log(10) + end + + # Calculating Taylor series sum using binary splitting method + # Calculates f(x) = (x/d0)*(1+(x/d1)*(1+(x/d2)*(1+(x/d3)*(1+...)))) + # x.n_significant_digits or ds.size must be small to be performant. + def self.taylor_sum_binary_splitting(x, ds, prec) # :nodoc: + fs = ds.map {|d| [0, BigDecimal(d)] } + # fs = [[a0, a1], [b0, b1], [c0, c1], ...] + # f(x) = a0/a1+(x/a1)*(1+b0/b1+(x/b1)*(1+c0/c1+(x/c1)*(1+d0/d1+(x/d1)*(1+...)))) + while fs.size > 1 + # Merge two adjacent fractions + # from: (1 + a0/a1 + x/a1 * (1 + b0/b1 + x/b1 * rest)) + # to: (1 + (a0*b1+x*(b0+b1))/(a1*b1) + (x*x)/(a1*b1) * rest) + xn = xn ? xn.mult(xn, prec) : x + fs = fs.each_slice(2).map do |(a, b)| + b ||= [0, BigDecimal(1)._decimal_shift([xn.exponent, 0].max + 2)] + [ + (a[0] * b[1]).add(xn * (b[0] + b[1]), prec), + a[1].mult(b[1], prec) + ] + end + end + BigDecimal(fs[0][0]).div(fs[0][1], prec) + end end # call-seq: @@ -226,9 +266,7 @@ def sqrt(prec) ex = exponent / 2 x = _decimal_shift(-2 * ex) y = BigDecimal(Math.sqrt(x.to_f), 0) - precs = [prec + BigDecimal.double_fig] - precs << 2 + precs.last / 2 while precs.last > BigDecimal.double_fig - precs.reverse_each do |p| + Internal.newton_loop(prec + BigDecimal.double_fig) do |p| y = y.add(x.div(y, p), p).div(2, p) end y._decimal_shift(ex).mult(1, prec) @@ -264,59 +302,32 @@ def log(x, prec) return BigDecimal(0) if x == 1 prec2 = prec + BigDecimal.double_fig - BigDecimal.save_limit do - BigDecimal.limit(0) - if x > 10 || x < 0.1 - log10 = log(BigDecimal(10), prec2) - exponent = x.exponent - x = x._decimal_shift(-exponent) - if x < 0.3 - x *= 10 - exponent -= 1 - end - return (log10 * exponent).add(log(x, prec2), prec) - end - - x_minus_one_exponent = (x - 1).exponent - # log(x) = log(sqrt(sqrt(sqrt(sqrt(x))))) * 2**sqrt_steps - sqrt_steps = [Integer.sqrt(prec2) + 3 * x_minus_one_exponent, 0].max - - lg2 = 0.3010299956639812 - sqrt_prec = prec2 + [-x_minus_one_exponent, 0].max + (sqrt_steps * lg2).ceil - - sqrt_steps.times do - x = x.sqrt(sqrt_prec) - end - - # Taylor series for log(x) around 1 - # log(x) = -log((1 + X) / (1 - X)) where X = (x - 1) / (x + 1) - # log(x) = 2 * (X + X**3 / 3 + X**5 / 5 + X**7 / 7 + ...) - x = (x - 1).div(x + 1, sqrt_prec) - y = x - x2 = x.mult(x, prec2) - 1.step do |i| - n = prec2 + x.exponent - y.exponent + x2.exponent - break if n <= 0 || x.zero? - x = x.mult(x2.round(n - x2.exponent), n) - y = y.add(x.div(2 * i + 1, n), prec2) - end + if x < 0.1 || x > 10 + exponent = (3 * x).exponent - 1 + x = x._decimal_shift(-exponent) + return log(10, prec2).mult(exponent, prec2).add(log(x, prec2), prec) + end - y.mult(2 ** (sqrt_steps + 1), prec) + # Solve exp(y) - x = 0 with Newton's method + # Repeat: y -= (exp(y) - x) / exp(y) + y = BigDecimal(BigDecimal::Internal.float_log(x), 0) + exp_additional_prec = [-(x - 1).exponent, 0].max + BigDecimal::Internal.newton_loop(prec2) do |p| + expy = exp(y, p + exp_additional_prec) + y = y.sub(expy.sub(x, p).div(expy, p), p) end + y.mult(1, prec) end - # Taylor series for exp(x) around 0 - private_class_method def _exp_taylor(x, prec) # :nodoc: - xn = BigDecimal(1) - y = BigDecimal(1) - 1.step do |i| - n = prec + xn.exponent - break if n <= 0 || xn.zero? - xn = xn.mult(x, n).div(i, n) - y = y.add(xn, prec) - end - y + private_class_method def _exp_binary_splitting(x, prec) # :nodoc: + return BigDecimal(1) if x.zero? + # Find k that satisfies x**k / k! < 10**(-prec) + log10 = Math.log(10) + logx = BigDecimal::Internal.float_log(x.abs) + step = (1..).bsearch { |k| Math.lgamma(k + 1)[0] - k * logx > prec * log10 } + # exp(x)-1 = x*(1+x/2*(1+x/3*(1+x/4*(1+x/5*(1+...))))) + 1 + BigDecimal::Internal.taylor_sum_binary_splitting(x, [*1..step], prec) end # call-seq: @@ -341,11 +352,21 @@ def exp(x, prec) prec2 = prec + BigDecimal.double_fig + cnt x = x._decimal_shift(-cnt) - # Calculation of exp(small_prec) is fast because calculation of x**n is fast - # Calculation of exp(small_abs) converges fast. - # exp(x) = exp(small_prec_part + small_abs_part) = exp(small_prec_part) * exp(small_abs_part) - x_small_prec = x.round(Integer.sqrt(prec2)) - y = _exp_taylor(x_small_prec, prec2).mult(_exp_taylor(x.sub(x_small_prec, prec2), prec2), prec2) + # Decimal form of bit-burst algorithm + # Calculate exp(x.xxxxxxxxxxxxxxxx) as + # exp(x.xx) * exp(0.00xx) * exp(0.0000xxxx) * exp(0.00000000xxxxxxxx) + x = x.mult(1, prec2) + n = 2 + y = BigDecimal(1) + BigDecimal.save_limit do + BigDecimal.limit(0) + while x != 0 do + partial_x = x.truncate(n) + x -= partial_x + y = y.mult(_exp_binary_splitting(partial_x, prec2), prec2) + n *= 2 + end + end # calculate exp(x * 10**cnt) from exp(x) # exp(x * 10**k) = exp(x * 10**(k - 1)) ** 10 diff --git a/lib/bigdecimal/math.rb b/lib/bigdecimal/math.rb index d0d49cb8..b7085c14 100644 --- a/lib/bigdecimal/math.rb +++ b/lib/bigdecimal/math.rb @@ -88,6 +88,37 @@ def sqrt(x, prec) end end + private_class_method def _sin_binary_splitting(x, prec) # :nodoc: + return x if x.zero? + x2 = x.mult(x, prec) + # Find k that satisfies x2**k / (2k+1)! < 10**(-prec) + log10 = Math.log(10) + logx = BigDecimal::Internal.float_log(x.abs) + step = (1..).bsearch { |k| Math.lgamma(2 * k + 1)[0] - 2 * k * logx > prec * log10 } + # Construct denominator sequence for binary splitting + # sin(x) = x*(1-x2/(2*3)*(1-x2/(4*5)*(1-x2/(6*7)*(1-x2/(8*9)*(1-...))))) + ds = (1..step).map {|i| -(2 * i) * (2 * i + 1) } + x.mult(1 + BigDecimal::Internal.taylor_sum_binary_splitting(x2, ds, prec), prec) + end + + private_class_method def _sin_around_zero(x, prec) # :nodoc: + # Divide x into several parts + # sin(x.xxxxxxxx...) = sin(x.xx + 0.00xx + 0.0000xxxx + ...) + # Calculate sin of each part and restore sin(0.xxxxxxxx...) using addition theorem. + sin = BigDecimal(0) + cos = BigDecimal(1) + n = 2 + while x != 0 do + partial_x = x.truncate(n) + x -= partial_x + s = _sin_binary_splitting(partial_x, prec) + c = (1 - s * s).sqrt(prec) + sin, cos = (sin * c).add(cos * s, prec), (cos * c).sub(sin * s, prec) + n *= 2 + end + sin.clamp(BigDecimal(-1), BigDecimal(1)) + end + # call-seq: # cbrt(decimal, numeric) -> BigDecimal # @@ -150,26 +181,9 @@ def sin(x, prec) prec = BigDecimal::Internal.coerce_validate_prec(prec, :sin) x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :sin) return BigDecimal::Internal.nan_computation_result if x.infinite? || x.nan? - n = prec + BigDecimal.double_fig - one = BigDecimal("1") - two = BigDecimal("2") + n = prec + BigDecimal.double_fig sign, x = _sin_periodic_reduction(x, n) - x1 = x - x2 = x.mult(x,n) - y = x - d = y - i = one - z = one - while d.nonzero? && ((m = n - (y.exponent - d.exponent).abs) > 0) - m = BigDecimal.double_fig if m < BigDecimal.double_fig - x1 = -x2.mult(x1,n) - i += two - z *= (i-one) * i - d = x1.div(z,m) - y += d - end - y = BigDecimal("1") if y > 1 - y.mult(sign, prec) + _sin_around_zero(x, n).mult(sign, prec) end # call-seq: @@ -187,8 +201,9 @@ def cos(x, prec) prec = BigDecimal::Internal.coerce_validate_prec(prec, :cos) x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :cos) return BigDecimal::Internal.nan_computation_result if x.infinite? || x.nan? - sign, x = _sin_periodic_reduction(x, prec + BigDecimal.double_fig, add_half_pi: true) - sign * sin(x, prec) + n = prec + BigDecimal.double_fig + sign, x = _sin_periodic_reduction(x, n, add_half_pi: true) + _sin_around_zero(x, n).mult(sign, prec) end # call-seq: @@ -277,28 +292,21 @@ def atan(x, prec) x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :atan) return BigDecimal::Internal.nan_computation_result if x.nan? n = prec + BigDecimal.double_fig - pi = PI(n) + return PI(n).div(2 * x.infinite?, prec) if x.infinite? + x = -x if neg = x < 0 - return pi.div(neg ? -2 : 2, prec) if x.infinite? - return pi.div(neg ? -4 : 4, prec) if x.round(prec) == 1 - x = BigDecimal("1").div(x, n) if inv = x > 1 - x = (-1 + sqrt(1 + x.mult(x, n), n)).div(x, n) if dbl = x > 0.5 - y = x - d = y - t = x - r = BigDecimal("3") - x2 = x.mult(x,n) - while d.nonzero? && ((m = n - (y.exponent - d.exponent).abs) > 0) - m = BigDecimal.double_fig if m < BigDecimal.double_fig - t = -t.mult(x2,n) - d = t.div(r,m) - y += d - r += 2 + x = BigDecimal(1).div(x, n) if inv = x < -1 || x > 1 + + # Solve tan(y) - x = 0 with Newton's method + # Repeat: y -= (tan(y) - x) * cos(y)**2 + y = BigDecimal(Math.atan(x.to_f), 0) + BigDecimal::Internal.newton_loop(n) do |p| + s = sin(y, p) + c = (1 - s * s).sqrt(p) + y = y.sub(c * (s.sub(c * x.mult(1, p), p)), p) end - y *= 2 if dbl - y = pi / 2 - y if inv - y = -y if neg - y.mult(1, prec) + y = PI(n) / 2 - y if inv + y.mult(neg ? -1 : 1, prec) end # call-seq: diff --git a/test/bigdecimal/test_bigmath.rb b/test/bigdecimal/test_bigmath.rb index 39dee611..1a00ca14 100644 --- a/test/bigdecimal/test_bigmath.rb +++ b/test/bigdecimal/test_bigmath.rb @@ -182,8 +182,13 @@ def test_sin assert_converge_in_precision {|n| sin(BigDecimal("1e-30"), n) } assert_converge_in_precision {|n| sin(BigDecimal(PI(50)), n) } assert_converge_in_precision {|n| sin(BigDecimal(PI(50) * 100), n) } - assert_operator(sin(PI(30) / 2, 30), :<=, 1) - assert_operator(sin(-PI(30) / 2, 30), :>=, -1) + [:up, :down].each do |mode| + BigDecimal.save_rounding_mode do + BigDecimal.mode(BigDecimal::ROUND_MODE, mode) + assert_operator(sin(PI(30) / 2, 30), :<=, 1) + assert_operator(sin(-PI(30) / 2, 30), :>=, -1) + end + end end def test_cos @@ -205,8 +210,13 @@ def test_cos assert_converge_in_precision {|n| cos(BigDecimal("1e50"), n) } assert_converge_in_precision {|n| cos(BigDecimal(PI(50) / 2), n) } assert_converge_in_precision {|n| cos(BigDecimal(PI(50) * 201 / 2), n) } - assert_operator(cos(PI(30), 30), :>=, -1) - assert_operator(cos(PI(30) * 2, 30), :<=, 1) + [:up, :down].each do |mode| + BigDecimal.save_rounding_mode do + BigDecimal.mode(BigDecimal::ROUND_MODE, mode) + assert_operator(cos(PI(30), 30), :>=, -1) + assert_operator(cos(PI(30) * 2, 30), :<=, 1) + end + end end def test_tan @@ -388,26 +398,20 @@ def test_exp def test_log assert_equal(0, log(BigDecimal("1.0"), 10)) - assert_in_epsilon(Math.log(10)*1000, log(BigDecimal("1e1000"), 10)) + assert_in_epsilon(1000 * Math.log(10), log(BigDecimal("1e1000"), 10)) + assert_in_epsilon(19999999999999 * Math.log(10), log(BigDecimal("1E19999999999999"), 10)) + assert_in_epsilon(-19999999999999 * Math.log(10), log(BigDecimal("1E-19999999999999"), 10)) assert_in_exact_precision( BigDecimal("2.3025850929940456840179914546843642076011014886287729760333279009675726096773524802359972050895982983419677840422862"), log(BigDecimal("10"), 100), 100 ) assert_converge_in_precision {|n| log(BigDecimal("2"), n) } - assert_converge_in_precision {|n| log(BigDecimal("1e-30") + 1, n) } - assert_converge_in_precision {|n| log(BigDecimal("1e-30"), n) } + assert_converge_in_precision {|n| log(1 + SQRT2 * BigDecimal("1e-30"), n) } + assert_converge_in_precision {|n| log(SQRT2 * BigDecimal("1e-30"), n) } assert_converge_in_precision {|n| log(BigDecimal("1e30"), n) } assert_converge_in_precision {|n| log(SQRT2, n) } assert_raise(Math::DomainError) {log(BigDecimal("-0.1"), 10)} - begin - x = BigDecimal("1E19999999999999") - rescue FloatDomainError - else - unless x.infinite? - assert_in_epsilon(Math.log(10) * 19999999999999, BigMath.log(x, 10)) - end - end end def test_log2 From a3ea7d6fe016b2a78acadf63e17d4a3ac18e4e86 Mon Sep 17 00:00:00 2001 From: tompng Date: Fri, 19 Sep 2025 19:54:19 +0900 Subject: [PATCH 4/5] Drop Ruby 2.5 support bsearch for endless range is only available in ruby >= 2.6 --- .github/workflows/ci.yml | 2 +- bigdecimal.gemspec | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ae9a3b06..a4f7f5da 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: uses: ruby/actions/.github/workflows/ruby_versions.yml@master with: engine: cruby-truffleruby - min_version: 2.5 + min_version: 2.6 versions: '["debug"]' host: diff --git a/bigdecimal.gemspec b/bigdecimal.gemspec index 6b20ac08..774fd223 100644 --- a/bigdecimal.gemspec +++ b/bigdecimal.gemspec @@ -53,7 +53,7 @@ Gem::Specification.new do |s| ] end - s.required_ruby_version = Gem::Requirement.new(">= 2.5.0") + s.required_ruby_version = Gem::Requirement.new(">= 2.6.0") s.metadata["changelog_uri"] = s.homepage + "/blob/master/CHANGES.md" end From ace2d325157b829353c0e57a17c0304165e20ba7 Mon Sep 17 00:00:00 2001 From: tompng Date: Wed, 26 Nov 2025 14:21:01 +0900 Subject: [PATCH 5/5] BigMath.erf and BigMath.erfc with bit burst algorithm --- lib/bigdecimal/math.rb | 336 +++++++++++++++++++++++++++++--- test/bigdecimal/test_bigmath.rb | 61 ++++++ 2 files changed, 366 insertions(+), 31 deletions(-) diff --git a/lib/bigdecimal/math.rb b/lib/bigdecimal/math.rb index b7085c14..89bf9e12 100644 --- a/lib/bigdecimal/math.rb +++ b/lib/bigdecimal/math.rb @@ -576,6 +576,300 @@ def expm1(x, prec) exp_prec > 0 ? exp(x, exp_prec).sub(1, prec) : BigDecimal(-1) end + # call-seq: + # erf(decimal, numeric) -> BigDecimal + # + # Computes the error function of +decimal+ to the specified number of digits of + # precision, +numeric+. + # + # If +decimal+ is NaN, returns NaN. + # + # BigMath.erf(BigDecimal('1'), 32).to_s + # #=> "0.84270079294971486934122063508261e0" + # + def erf(x, prec) + prec = BigDecimal::Internal.coerce_validate_prec(prec, :erf) + x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :erf) + return BigDecimal::Internal.nan_computation_result if x.nan? + return BigDecimal(x.infinite?) if x.infinite? + return BigDecimal(0) if x == 0 + return -erf(-x, prec) if x < 0 + return BigDecimal(1) if x > 5000000000 # erf(5000000000) > 1 - 1e-10000000000000000000 + + if x > 8 + xf = x.to_f + log10_erfc = -xf ** 2 / Math.log(10) - Math.log10(xf * Math::PI ** 0.5) + erfc_prec = [prec + log10_erfc.ceil, 1].max + erfc = _erfc_bit_burst(x, erfc_prec + BigDecimal.double_fig) + return BigDecimal(1).sub(erfc, prec) if erfc + end + + _erf_bit_burst(x, prec + BigDecimal.double_fig).mult(1, prec) + end + + def erfc(x, prec) + prec = BigDecimal::Internal.coerce_validate_prec(prec, :erfc) + x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :erfc) + return BigDecimal::Internal.nan_computation_result if x.nan? + return BigDecimal(1 - x.infinite?) if x.infinite? + return BigDecimal(1).sub(erf(x, prec), prec) if x < 0 + return BigDecimal(0) if x > 5000000000 # erfc(5000000000) < 1e-10000000000000000000 (underflow) + + if x >= 8 + y = _erfc_bit_burst(x, prec + BigDecimal.double_fig) + return y.mult(1, prec) if y + end + + # erfc(x) = 1 - erf(x) < exp(-x**2)/x/sqrt(pi) + # Precision of erf(x) needs about log10(exp(-x**2)) extra digits + log10 = 2.302585092994046 + high_prec = prec + BigDecimal.double_fig + (x.to_f**2 / log10).ceil + BigDecimal(1).sub(_erf_bit_burst(x, high_prec), prec) + end + + # Calculates erf(x) using bit-burst algorithm. + private_class_method def _erf_bit_burst(x, prec) + x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :erf) + prec = BigDecimal::Internal.coerce_validate_prec(prec, :erf) + + return BigDecimal(0) if x > 5000000000 # erfc underflows + x = x.mult(1, [prec - (x.to_f**2/Math.log(10)).floor, 1].max) + + calculated_x = BigDecimal(0) + erf_exp2 = BigDecimal(0) + digits = 8 + xf = x.to_f + scale = 2 * exp(-x.mult(x, prec), prec).div(PI(prec).sqrt(prec), prec) + + until x.zero? + partial = x.truncate(digits) + digits *= 2 + next if partial.zero? + + erf_exp2 = _erf_exp2_binary_splitting(partial, calculated_x, erf_exp2, prec) + calculated_x += partial + x -= partial + end + erf_exp2.mult(scale, prec) + end + + # Calculates erfc(x) using bit-burst algorithm. + private_class_method def _erfc_bit_burst(x, prec) # :nodoc: + digits = (x.exponent + 1) * 40 + + calculated_x = x.truncate(digits) + f = _erfc_exp2_asymptotic_binary_splitting(calculated_x, prec) + return unless f + + scale = 2 * exp(-x.mult(x, prec), prec).div(PI(prec).sqrt(prec), prec) + x -= calculated_x + + until x.zero? + digits *= 2 + partial = x.truncate(digits) + next if partial.zero? + + f = _erfc_exp2_inv_inv_binary_splitting(partial, calculated_x, f, prec) + calculated_x += partial + x -= partial + end + f.mult(scale, prec) + end + + # Matrix multiplication for binary splitting method in erf/erfc calculation + private_class_method def _bs_matrix_mult(m1, m2, size, prec) # :nodoc: + (size * size).times.map do |i| + size.times.map do |k| + m1[i / size * size + k].mult(m2[size * k + i % size], prec) + end.reduce {|a, b| a.add(b, prec) } + end + end + + # Matrix/Vector weighted sum for binary splitting method in erf/erfc calculation + private_class_method def _bs_weighted_sum(m1, w1, m2, w2, prec) # :nodoc: + m1.zip(m2).map {|v1, v2| (v1 * w1).add(v2 * w2, prec) } + end + + # Calculates Taylor expansion of erf(x+a)*exp((x+a)**2)*sqrt(pi)/2 with binary splitting method. + private_class_method def _erf_exp2_binary_splitting(x, a, f_a, prec) # :nodoc: + cexponent = Math.log10([2 * a, Math.sqrt(2)].max.to_f) + log10(x.abs, 10) + log10f = Math.log(10) + + steps = BigDecimal.save_exception_mode do + BigDecimal.mode(BigDecimal::EXCEPTION_UNDERFLOW, false) + (2..).bsearch do |n| + x.to_f ** 2 < n && n * cexponent + Math.lgamma(n / 2)[0] / log10f + n * Math.log10(2) - Math.lgamma(n - 1)[0] / log10f < -prec + x.to_f**2 / log10f + end + end + + if a == 0 + denominators = (steps / 2).times.map {|i| 2 * i + 3 } + return x.mult(1 + BigDecimal::Internal.taylor_sum_binary_splitting(2 * x * x, denominators, prec), prec) + end + + # First, calculate a matrix that represents the sum of the Taylor series: + # SumMatrix = (((((...+I)x*M4+I)*x*M3+I)*M2*x+I)*M1*x+I) + # Where Mi is a 2x2 matrix that generates the next coefficients of Taylor series: + # Vector(c4, c5) = M4*M3*M2*M1*Vector(c0, c1) + # And then calculates: + # SumMatrix * Vector(c0, c1) = Vector(c0+c1*x+c2*x**2+..., _) + # In this binary splitting method, adjacent two operations are combined into one repeatedly. + # ((...) * x * A + B) / C is the form of each operation. A and B are 2x2 matrices, C is a scalar. + zero = BigDecimal(0) + two = BigDecimal(2) + two_a = two * a + operations = steps.times.map do |i| + n = BigDecimal(2 + i) + [[zero, n, two, two_a], [n, zero, zero, n], n] + end + + while operations.size > 1 + xpow = xpow ? xpow.mult(xpow, prec) : x.mult(1, prec) + operations = operations.each_slice(2).map do |op1, op2| + # Combine two operations into one: + # (((Remaining * x * A2 + B2) / C2) * x * A1 + B1) / C1 + # ((Remaining * (x*x) * (A2*A1) + (x*B2*A1+B1*C2)) / (C1*C2) + # Therefore, combined operation can be represented as: + # Anext = A2 * A1 + # Bnext = x * B2 * A1 + B1 * C2 + # Cnext = C1 * C2 + # xnext = x * x + a1, b1, c1 = op1 + a2, b2, c2 = op2 || [[zero] * 4, [zero] * 4, BigDecimal(1)] + [ + _bs_matrix_mult(a2, a1, 2, prec), + _bs_weighted_sum(_bs_matrix_mult(b2, a1, 2, prec), xpow, b1, c2, prec), + c1.mult(c2, prec), + ] + end + end + _, sum_matrix, denominator = operations.first + (sum_matrix[1] + f_a * (2 * a * sum_matrix[1] + sum_matrix[0])).div(denominator, prec) + end + + # Calculates asymptotic expansion of erfc(x)*exp(x**2)*sqrt(pi)/2 with binary splitting method + private_class_method def _erfc_exp2_asymptotic_binary_splitting(x, prec) # :nodoc: + # Let f(x) = erfc(x)*sqrt(pi)*exp(x**2)/2 + # f(x) satisfies the following differential equation: + # 2*x*f(x) = f'(x) + 1 + # From the above equation, we can derive the following asymptotic expansion: + # f(x) = (0..kmax).sum { (-1)**k * (2*k)! / 4**k / k! / x**(2*k)) } / x + + # This asymptotic expansion does not converge. + # But if there is a k that satisfies (2*k)! / 4**k / k! / x**(2*k) < 10**(-prec), + # It is enough to calculate erfc within the given precision. + # Using Stirling's approximation, we can simplify this condition to: + # sqrt(2)/2 + k*log(k) - k - 2*k*log(x) < -prec*log(10) + # and the left side is minimized when k = x**2. + xf = x.to_f + kmax = (1..(xf ** 2).floor).bsearch do |k| + Math.log(2) / 2 + k * Math.log(k) - k - 2 * k * Math.log(xf) < -prec * Math.log(10) + end + return unless kmax + + # Convert asymptotic expansion to nested form: + # 1 + a/x + a*b/x/x + a*b*c/x/x/x + a*b*c/x/x/x*rest + # = 1 + (a/x) * (1 + (b/x) * (1 + (c/x) * (1 + rest))) + # + # And calculate it with binary splitting: + # (a1/d + b1/d * (a2/d + b2/d * (rest))) + # = ((a1*d+b1*a2)/(d*d) + b1*b2/(d*denominator) * (rest))) + denominator = x.mult(x, prec).mult(2, prec) + fractions = (1..kmax).map do |k| + [denominator, BigDecimal(1 - 2 * k)] + end + while fractions.size > 1 + fractions = fractions.each_slice(2).map do |fraction1, fraction2| + a1, b1 = fraction1 + a2, b2 = fraction2 || [BigDecimal(0), denominator] + [ + a1.mult(denominator, prec).add(b1.mult(a2, prec), prec), + b1.mult(b2, prec), + ] + end + denominator = denominator.mult(denominator, prec) + end + sum = fractions[0][0].add(fractions[0][1], prec).div(denominator, prec) + sum.div(x, prec) / 2 + end + + # Calculates f(1/(a + x)) where f(x) = (sqrt(pi)/2) * exp(1/x**2) * erfc(1/x) + # f(1/(a+x)) = f(1/a - x/(a*(a+x))) + private_class_method def _erfc_exp2_inv_inv_binary_splitting(x, a, f_inva, prec) + return f_inva if x.zero? + # f(x) satisfies the following differential equation: + # (1/a+w)**3*f'(1/a+w) + 2*f(1/a+w) = 1/a + w + # From the above equation, we can derive the following Taylor expansion around x=a: + # Coefficients: f(1/a + w) = c0 + c1*w + c2*w**2 + c3*w**3 + ... + # Constraints: + # (w**3 + 3*w**2/a + 3*w/a**2 + 1/a**3) * (c1 + 2*c2*w + 3*c3*w**2 + 4*c4*w**3 + ...) + # + 2 * (c0 + c1*w + c2*w**2 + c3*w**3 + ...) = 1/a + w + # Recurrence relations: + # c0 = f(1/a) + # c1 = a**2 - 2*c0*a**3 + # c2 = (a**3 - 3*c1*a - 2*c1*a**3) / 2 + # c3 = -(3*c1*a**2 + 6*c2*a + 2*c2*a**3) / 3 + # c(n) = -((n-3)*c(n-3)*a**3 + 3*(n-2)*c(n-2)*a**2 + 3*(n-1)*c(n-1)*a + 2*c(n-1)*a**3) / n + + aa = a.mult(a, prec) + aaa = aa.mult(a, prec) + c0 = f_inva + c1 = (aa - 2 * c0 * aaa).mult(1, prec) + c2 = (aaa - 3 * c1 * a - 2 * c1 * aaa).div(2, prec) + + # Estimate the number of steps needed to achieve the required precision + low_prec = 16 + w = x.div(a.mult(a + x, low_prec), low_prec) + wpow = w.mult(w, low_prec) + cm3, cm2, cm1 = [c0, c1, c2].map {|v| v.mult(1, low_prec) } + a_low, aa_low, aaa_low = [a, aa, aaa].map {|v| v.mult(1, low_prec) } + step = (3..).find do |n| + wpow = wpow.mult(w, low_prec) + cn = -((n - 3) * cm3 * aaa_low + 3 * aa_low * (n - 2) * cm2 + 3 * a_low * (n - 1) * cm1 + 2 * cm1 * aaa_low).div(n, low_prec) + cm3, cm2, cm1 = cm2, cm1, cn + cn.mult(wpow, low_prec).exponent < -prec + end + + # Let M(n) be a 3x3 matrix that transforms (c(n),c(n+1),c(n+2)) to (c(n-1),c(n),c(n+1)) + # Mn = | 0 1 0 | + # | 0 0 1 | + # | -(n-3)*aaa/n -3*(n-2)*aa/n -2*aaa-3*(n-1)*a/n | + # Vector(c6,c7,c8) = M6*M5*M4*M3*M2*M1 * Vector(c0,c1,c2) + # Vector(c0+c1*y/z+c2*(y/z)**2+..., _, _) = (((... + I)*M3*y/z + I)*M2*y/z + I)*M1*y/z + I) * Vector(c2, c1, c0) + # Perform binary splitting on this nested parenthesized calculation by using the following formula: + # (((...)*A2*y/z + B2)/D2 * A1*y/z + B1)/D1 = (((...)*(A2*A1)*(y*y)/z + (B2*A1*y+z*D2*B1)) / (D1*D2*z) + # where A_n, Bn are matrices and Dn are scalars + + zero = BigDecimal(0) + operations = (3..step + 2).map do |n| + bign = BigDecimal(n) + [ + [ + zero, bign, zero, + zero, zero, bign, + BigDecimal(-(n - 3) * aaa), -3 * (n - 2) * aa, -2 * aaa - 3 * (n - 1) * a + ], + [bign, zero, zero, zero, bign, zero, zero, zero, bign], + bign + ] + end + + z = a.mult(a + x, prec) + while operations.size > 1 + y = y ? y.mult(y, prec) : -x.mult(1, prec) + operations = operations.each_slice(2).map do |op1, op2| + a1, b1, d1 = op1 + a2, b2, d2 = op2 || [[zero] * 9, [zero] * 9, BigDecimal(1)] + [ + _bs_matrix_mult(a2, a1, 3, prec), + _bs_weighted_sum(_bs_matrix_mult(b2, a1, 3, prec), y, b1, d2.mult(z, prec), prec), + d1.mult(d2, prec).mult(z, prec), + ] + end + end + _, sum_matrix, denominator = operations[0] + (sum_matrix[0] * c0 + sum_matrix[1] * c1 + sum_matrix[2] * c2).div(denominator, prec) + end # call-seq: # PI(numeric) -> BigDecimal @@ -588,38 +882,18 @@ def expm1(x, prec) # def PI(prec) prec = BigDecimal::Internal.coerce_validate_prec(prec, :PI) - n = prec + BigDecimal.double_fig - zero = BigDecimal("0") - one = BigDecimal("1") - two = BigDecimal("2") - - m25 = BigDecimal("-0.04") - m57121 = BigDecimal("-57121") - - pi = zero - - d = one - k = one - t = BigDecimal("-80") - while d.nonzero? && ((m = n - (pi.exponent - d.exponent).abs) > 0) - m = BigDecimal.double_fig if m < BigDecimal.double_fig - t = t*m25 - d = t.div(k,m) - k = k+two - pi = pi + d - end - - d = one - k = one - t = BigDecimal("956") - while d.nonzero? && ((m = n - (pi.exponent - d.exponent).abs) > 0) - m = BigDecimal.double_fig if m < BigDecimal.double_fig - t = t.div(m57121,n) - d = t.div(k,m) - pi = pi + d - k = k+two + n = prec + BigDecimal.double_fig + a = BigDecimal(1) + b = BigDecimal(0.5, 0).sqrt(n) + s = BigDecimal(0.25, 0) + t = 1 + while a != b && (a - b).exponent > 1 - n + c = (a - b).div(2, n) + a, b = (a + b).div(2, n), (a * b).sqrt(n) + s = s.sub(c * c * t, n) + t *= 2 end - pi.mult(1, prec) + (a * b).div(s, prec) end # call-seq: diff --git a/test/bigdecimal/test_bigmath.rb b/test/bigdecimal/test_bigmath.rb index 1a00ca14..9cb27e85 100644 --- a/test/bigdecimal/test_bigmath.rb +++ b/test/bigdecimal/test_bigmath.rb @@ -473,4 +473,65 @@ def test_expm1 assert_in_exact_precision(exp(BigDecimal("1.23e-10"), 120) - 1, expm1(BigDecimal("1.23e-10"), 100), 100) assert_in_exact_precision(exp(123, 120) - 1, expm1(BigDecimal("123"), 100), 100) end + + def test_erf + [-0.5, 0.1, 0.3, 2.1, 3.3].each do |x| + assert_in_epsilon(Math.erf(x), BigMath.erf(BigDecimal(x.to_s), N)) + end + assert_equal(1, BigMath.erf(PINF, N)) + assert_equal(-1, BigMath.erf(MINF, N)) + assert_equal(1, BigMath.erf(BigDecimal(1000), 100)) + assert_equal(-1, BigMath.erf(BigDecimal(-1000), 100)) + assert_not_equal(1, BigMath.erf(BigDecimal(10), 45)) + assert_not_equal(1, BigMath.erf(BigDecimal(15), 100)) + assert_equal(1, BigMath.erf(BigDecimal('1e400'), 10)) + assert_equal(-1, BigMath.erf(BigDecimal('-1e400'), 10)) + assert_equal( + BigDecimal("0.9953222650189527341620692563672529286108917970400600767383523262004372807199951773676290080196806805"), + BigMath.erf(BigDecimal("2"), 100) + ) + assert_converge_in_precision {|n| BigMath.erf(BigDecimal("1e-30"), n) } + assert_converge_in_precision {|n| BigMath.erf(BigDecimal("0.3"), n) } + assert_converge_in_precision {|n| BigMath.erf(SQRT2, n) } + end + + def test_erfc + [-0.5, 0.1, 0.3, 2.1, 3.3].each do |x| + assert_in_epsilon(Math.erfc(x), BigMath.erfc(BigDecimal(x.to_s), N)) + end + assert_equal(0, BigMath.erfc(PINF, N)) + assert_equal(2, BigMath.erfc(MINF, N)) + assert_equal(0, BigMath.erfc(BigDecimal('1e400'), 10)) + assert_equal(2, BigMath.erfc(BigDecimal('-1e400'), 10)) + + # erfc with taylor series + assert_equal( + BigDecimal("2.088487583762544757000786294957788611560818119321163727012213713938174695833440290610766384285723554e-45"), + BigMath.erfc(BigDecimal("10"), 100) + ) + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(0.3), n) } + assert_converge_in_precision {|n| BigMath.erfc(SQRT2, n) } + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(8), n) } + # erfc with asymptotic expansion + assert_equal( + BigDecimal("1.896961059966276509268278259713415434936907563929186183462834752900411805205111886605256690776760041e-697"), + BigMath.erfc(BigDecimal("40"), 100) + ) + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(30), n) } + assert_converge_in_precision {|n| BigMath.erfc(30 * SQRT2, n) } + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(50), n) } + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(60000), n) } + # Near crossover point between taylor series and asymptotic expansion around prec=150 + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(19.5), n) } + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(20.5), n) } + end + + def test_erf_erfc_consistency_large_prec + [BigDecimal(34.5), 34 + BigDecimal(4).div(7, 1200)].each do |x| + erf = BigMath.erf(x, 1200) # Calculated with taylor series of erf + erfc = BigMath.erfc(x, 400) # Calculated with asymptotic expansion + erfc2 = 1 - erf + assert_equal(erfc, erfc2.mult(1, 400)) + end + end end