diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ae9a3b06..a4f7f5da 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -15,7 +15,7 @@ jobs:
     uses: ruby/actions/.github/workflows/ruby_versions.yml@master
     with:
       engine: cruby-truffleruby
-      min_version: 2.5
+      min_version: 2.6
       versions: '["debug"]'
 
   host:
diff --git a/bigdecimal.gemspec b/bigdecimal.gemspec
index b6ef8fd9..774fd223 100644
--- a/bigdecimal.gemspec
+++ b/bigdecimal.gemspec
@@ -43,15 +43,17 @@ Gem::Specification.new do |s|
       ext/bigdecimal/bigdecimal.c
       ext/bigdecimal/bigdecimal.h
       ext/bigdecimal/bits.h
+      ext/bigdecimal/div.h
       ext/bigdecimal/feature.h
       ext/bigdecimal/missing.c
       ext/bigdecimal/missing.h
+      ext/bigdecimal/ntt.h
       ext/bigdecimal/missing/dtoa.c
       ext/bigdecimal/static_assert.h
     ]
   end
 
-  s.required_ruby_version = Gem::Requirement.new(">= 2.5.0")
+  s.required_ruby_version = Gem::Requirement.new(">= 2.6.0")
 
   s.metadata["changelog_uri"] = s.homepage + "/blob/master/CHANGES.md"
 end
diff --git a/ext/bigdecimal/bigdecimal.c b/ext/bigdecimal/bigdecimal.c
index d9247790..5c249270 100644
--- a/ext/bigdecimal/bigdecimal.c
+++ b/ext/bigdecimal/bigdecimal.c
@@ -29,10 +29,18 @@
 #endif
 
 #include "bits.h"
+#include "div.h"
 #include "static_assert.h"
 
 #define BIGDECIMAL_VERSION "3.3.1"
 
+#if SIZEOF_DECDIG == 4
+#define USE_NTT_MULTIPLICATION 1
+#include "ntt.h"
+#define NTT_MULTIPLICATION_THRESHOLD 100
+#define NEWTON_RAPHSON_DIVISION_THRESHOLD 200
+#endif
+
 /* #define ENABLE_NUMERIC_STRING */
 
 #define SIGNED_VALUE_MAX INTPTR_MAX
@@ -75,11 +83,6 @@ static struct {
     uint8_t mode;
 } rbd_rounding_modes[RBD_NUM_ROUNDING_MODES];
 
-typedef struct {
-    VALUE bigdecimal;
-    Real *real;
-} BDVALUE;
-
 typedef struct {
     VALUE bigdecimal_or_nil;
     Real *real_or_null;
@@ -207,7 +210,6 @@ rbd_allocate_struct_zero(int sign, size_t const digits)
 static unsigned short VpGetException(void);
 static void VpSetException(unsigned short f);
 static void VpCheckException(Real *p, bool always);
-static int AddExponent(Real *a, SIGNED_VALUE n);
 static VALUE CheckGetValue(BDVALUE v);
 static void VpInternalRound(Real *c, size_t ixDigit, DECDIG vPrev, DECDIG v);
 static int VpLimitRound(Real *c, size_t ixDigit);
@@ -1112,9 +1114,6 @@ BigDecimal_check_num(Real *p)
     VpCheckException(p, true);
 }
 
-static VALUE BigDecimal_fix(VALUE self);
-static VALUE BigDecimal_split(VALUE self);
-
 /* Returns the value as an Integer.
  *
  * If the BigDecimal is infinity or NaN, raises FloatDomainError.
@@ -3257,19 +3256,39 @@ BigDecimal_literal(const char *str)
 
 #ifdef BIGDECIMAL_USE_VP_TEST_METHODS
 VALUE
-BigDecimal_vpdivd(VALUE self, VALUE r, VALUE cprec) {
-  BDVALUE a,b,c,d;
+BigDecimal_vpdivd_generic(VALUE self, VALUE r, VALUE cprec, void (*vpdivd_func)(Real*, Real*, Real*, Real*)) {
+  BDVALUE a, b, c, d;
   size_t cn = NUM2INT(cprec);
   a = GetBDValueMust(self);
   b = GetBDValueMust(r);
   c = NewZeroWrap(1, cn * BASE_FIG);
   d = NewZeroWrap(1, VPDIVD_REM_PREC(a.real, b.real, c.real) * BASE_FIG);
-  VpDivd(c.real, d.real, a.real, b.real);
+  vpdivd_func(c.real, d.real, a.real, b.real);
   RB_GC_GUARD(a.bigdecimal);
   RB_GC_GUARD(b.bigdecimal);
   return rb_assoc_new(c.bigdecimal, d.bigdecimal);
 }
 
+void
+VpDivdNormal(Real *c, Real *r, Real *a, Real *b) {
+  VpDivd(c, r, a, b);
+}
+
+VALUE
+BigDecimal_vpdivd(VALUE self, VALUE r, VALUE cprec) {
+  return BigDecimal_vpdivd_generic(self, r, cprec, VpDivdNormal);
+}
+
+VALUE
+BigDecimal_vpdivd_newton(VALUE self, VALUE r, VALUE cprec) {
+  return BigDecimal_vpdivd_generic(self, r, cprec, VpDivdNewton);
+}
+
+VALUE
+BigDecimal_newton_raphson_inverse(VALUE self, VALUE prec) {
+  return newton_raphson_inverse(self, NUM2SIZET(prec));
+}
+
 VALUE
 BigDecimal_vpmult(VALUE self, VALUE v) {
   BDVALUE a,b,c;
@@ -3281,6 +3300,25 @@ BigDecimal_vpmult(VALUE self, VALUE v) {
   RB_GC_GUARD(b.bigdecimal);
   return c.bigdecimal;
 }
+
+#if SIZEOF_DECDIG == 4
+VALUE
+BigDecimal_nttmult(VALUE self, VALUE v) {
+  BDVALUE a,b,c;
+  a = GetBDValueMust(self);
+  b = GetBDValueMust(v);
+  c = NewZeroWrap(1, VPMULT_RESULT_PREC(a.real, b.real) * BASE_FIG);
+  ntt_multiply(a.real->Prec, b.real->Prec, a.real->frac, b.real->frac, c.real->frac);
+  VpSetSign(c.real, a.real->sign * b.real->sign);
+  c.real->exponent = a.real->exponent + b.real->exponent;
+  c.real->Prec = a.real->Prec + b.real->Prec;
+  VpNmlz(c.real);
+  RB_GC_GUARD(a.bigdecimal);
+  RB_GC_GUARD(b.bigdecimal);
+  return c.bigdecimal;
+}
+#endif
+
 #endif /* BIGDECIMAL_USE_VP_TEST_METHODS */
 
 /* Document-class: BigDecimal
@@ -3652,7 +3690,12 @@ Init_bigdecimal(void)
 
 #ifdef BIGDECIMAL_USE_VP_TEST_METHODS
     rb_define_method(rb_cBigDecimal, "vpdivd", BigDecimal_vpdivd, 2);
+    rb_define_method(rb_cBigDecimal, "vpdivd_newton", BigDecimal_vpdivd_newton, 2);
+    rb_define_method(rb_cBigDecimal, "newton_raphson_inverse", BigDecimal_newton_raphson_inverse, 1);
     rb_define_method(rb_cBigDecimal, "vpmult", BigDecimal_vpmult, 1);
+#ifdef USE_NTT_MULTIPLICATION
+    rb_define_method(rb_cBigDecimal, "nttmult", BigDecimal_nttmult, 1);
+#endif
 #endif /* BIGDECIMAL_USE_VP_TEST_METHODS */
 
 #define ROUNDING_MODE(i, name, value) \
@@ -4935,6 +4978,15 @@ VpMult(Real *c, Real *a, Real *b)
     c->exponent = a->exponent; /* set exponent */
     VpSetSign(c, VpGetSign(a) * VpGetSign(b)); /* set sign */
     if (!AddExponent(c, b->exponent)) return 0;
+
+#ifdef USE_NTT_MULTIPLICATION
+    if (b->Prec >= NTT_MULTIPLICATION_THRESHOLD) {
+        ntt_multiply((uint32_t)a->Prec, (uint32_t)b->Prec, a->frac, b->frac, c->frac);
+        c->Prec = a->Prec + b->Prec;
+        goto Cleanup;
+    }
+#endif
+
     carry = 0;
     nc = ind_c = MxIndAB;
     memset(c->frac, 0, (nc + 1) * sizeof(DECDIG)); /* Initialize c */
@@ -4981,6 +5033,8 @@ VpMult(Real *c, Real *a, Real *b)
             }
         }
     }
+
+Cleanup:
     VpNmlz(c);
 
 Exit:
@@ -5028,6 +5082,14 @@ VpDivd(Real *c, Real *r, Real *a, Real *b)
 
     if (word_a > word_r || word_b + word_c - 2 >= word_r) goto space_error;
 
+#ifdef USE_NTT_MULTIPLICATION
+    // Newton-Raphson division only pays off when multiplication is faster than O(n^2), so it is used only with NTT multiplication and when both word_c and word_b reach the threshold
+    if (word_c >= NEWTON_RAPHSON_DIVISION_THRESHOLD && word_b >= NEWTON_RAPHSON_DIVISION_THRESHOLD) {
+        VpDivdNewton(c, r, a, b);
+        goto Exit;
+    }
+#endif
+
     for (i = 0; i < word_a; ++i) r->frac[i] = a->frac[i];
     for (i = word_a; i < word_r; ++i) r->frac[i] = 0;
     for (i = 0; i < word_c; ++i) c->frac[i] = 0;
diff --git a/ext/bigdecimal/bigdecimal.h b/ext/bigdecimal/bigdecimal.h
index 82c88a2a..71ddb21f 100644
--- a/ext/bigdecimal/bigdecimal.h
+++ b/ext/bigdecimal/bigdecimal.h
@@ -188,6 +188,11 @@ typedef struct {
     DECDIG frac[FLEXIBLE_ARRAY_SIZE]; /* Array of fraction part. */
 } Real;
 
+typedef struct {
+    VALUE bigdecimal;
+    Real *real;
+} BDVALUE;
+
 /*
  * ------------------
  * EXPORTables.
@@ -232,10 +237,31 @@ VP_EXPORT int VpActiveRound(Real *y, Real *x, unsigned short f, ssize_t il);
 VP_EXPORT int VpMidRound(Real *y, unsigned short f, ssize_t nf);
 VP_EXPORT int VpLeftRound(Real *y, unsigned short f, ssize_t nf);
 VP_EXPORT void VpFrac(Real *y, Real *x);
+VP_EXPORT int AddExponent(Real *a, SIGNED_VALUE n);
 
 /* VP constants */
 VP_EXPORT Real *VpOne(void);
 
+/*
+ * **** BigDecimal part: prototypes of Ruby-level functions and helpers (div.h calls several of these) ****
+ */
+VP_EXPORT VALUE BigDecimal_lt(VALUE self, VALUE r);
+VP_EXPORT VALUE BigDecimal_ge(VALUE self, VALUE r);
+VP_EXPORT VALUE BigDecimal_exponent(VALUE self);
+VP_EXPORT VALUE BigDecimal_fix(VALUE self);
+VP_EXPORT VALUE BigDecimal_frac(VALUE self);
+VP_EXPORT VALUE BigDecimal_add(VALUE self, VALUE b);
+VP_EXPORT VALUE BigDecimal_sub(VALUE self, VALUE b);
+VP_EXPORT VALUE BigDecimal_mult(VALUE self, VALUE b);
+VP_EXPORT VALUE BigDecimal_add2(VALUE self, VALUE b, VALUE n);
+VP_EXPORT VALUE BigDecimal_sub2(VALUE self, VALUE b, VALUE n);
+VP_EXPORT VALUE BigDecimal_mult2(VALUE self, VALUE b, VALUE n);
+VP_EXPORT VALUE BigDecimal_split(VALUE self);
+VP_EXPORT VALUE BigDecimal_decimal_shift(VALUE self, VALUE v);
+VP_EXPORT inline BDVALUE GetBDValueMust(VALUE v);
+VP_EXPORT inline BDVALUE rbd_allocate_struct_zero_wrap(int sign, size_t const digits);
+#define NewZeroWrap rbd_allocate_struct_zero_wrap
+
 /*
  * ------------------
  * MACRO definitions.
diff --git a/ext/bigdecimal/div.h b/ext/bigdecimal/div.h
new file mode 100644
index 00000000..e6dd89c9
--- /dev/null
+++ b/ext/bigdecimal/div.h
@@ -0,0 +1,205 @@
+// Calculate the inverse of x using the Newton-Raphson method.
+// x must be non-zero: its leading DECDIG is the divisor of the initial approximation.
+// The working precision is raised step by step until it reaches prec digits.
+static VALUE
+newton_raphson_inverse(VALUE x, size_t prec) {
+  BDVALUE bdone = NewZeroWrap(1, 1);
+  VpSetOne(bdone.real);
+  VALUE one = bdone.bigdecimal;
+
+  // Initial approximation in 2 digits
+  BDVALUE bdx = GetBDValueMust(x);
+  BDVALUE inv0 = NewZeroWrap(1, 2 * BIGDECIMAL_COMPONENT_FIGURES);
+  VpSetOne(inv0.real);
+  DECDIG_DBL numerator = (DECDIG_DBL)BIGDECIMAL_BASE * 100;
+  DECDIG_DBL denominator = (DECDIG_DBL)bdx.real->frac[0] * 100 + (DECDIG_DBL)(bdx.real->Prec >= 2 ? bdx.real->frac[1] : 0) * 100 / BIGDECIMAL_BASE;
+  inv0.real->frac[0] = (DECDIG)(numerator / denominator);
+  inv0.real->frac[1] = (DECDIG)((numerator % denominator) * (BIGDECIMAL_BASE / 100) / denominator * 100);
+  inv0.real->Prec = 2;
+  inv0.real->exponent = 1 - bdx.real->exponent;
+  VpNmlz(inv0.real);
+  RB_GC_GUARD(bdx.bigdecimal);
+  VALUE inv = inv0.bigdecimal;
+
+  int bl = 1;
+  while (((size_t)1 << bl) < prec) bl++;
+
+  for (int i = bl; i >= 0; i--) {
+    size_t n = (prec >> i) + 2;
+    if (n > prec) n = prec;
+    // Newton-Raphson iteration: inv_next = inv + inv * (1 - x * inv)
+    // (1 - x * inv) is already small, so n / 2 significant digits are enough for it.
+    VALUE one_minus_x_inv = BigDecimal_sub2(
+      one,
+      BigDecimal_mult(BigDecimal_mult2(x, one, SIZET2NUM(n + 1)), inv),
+      SIZET2NUM(n / 2)
+    );
+    inv = BigDecimal_add2(
+      inv,
+      BigDecimal_mult(inv, one_minus_x_inv),
+      SIZET2NUM(n)
+    );
+  }
+  return inv;
+}
+
+// Calculates divmod by multiplying approximate reciprocal of y
+// inv may be inexact: the two loops below adjust div and mod until 0 <= mod < y holds.
+static void
+divmod_by_inv_mul(VALUE x, VALUE y, VALUE inv, VALUE *res_div, VALUE *res_mod) {
+  VALUE div = BigDecimal_fix(BigDecimal_mult(x, inv));
+  VALUE mod = BigDecimal_sub(x, BigDecimal_mult(div, y));
+  while (RTEST(BigDecimal_lt(mod, INT2FIX(0)))) {
+    mod = BigDecimal_add(mod, y);
+    div = BigDecimal_sub(div, INT2FIX(1));
+  }
+  while (RTEST(BigDecimal_ge(mod, y))) {
+    mod = BigDecimal_sub(mod, y);
+    div = BigDecimal_add(div, INT2FIX(1));
+  }
+  *res_div = div;
+  *res_mod = mod;
+}
+
+// Copies `length` DECDIGs of the integer part of src into dest, skipping the lowest `rshift` integer DECDIGs.
+// Positions that fall outside src->frac are not written, so dest must be zero-filled by the caller.
+static void
+slice_copy(DECDIG *dest, Real *src, size_t rshift, size_t length) {
+  ssize_t start = src->exponent - rshift - length;
+  if (start >= (ssize_t)src->Prec) return;
+  if (start < 0) {
+    if (length <= (size_t)(-start)) return; // slice lies entirely above the leading digit of src
+    dest -= start;
+    length += start;
+    start = 0;
+  }
+  size_t max_length = src->Prec - start;
+  memcpy(dest, src->frac + start, Min(length, max_length) * sizeof(DECDIG));
+}
+
+/* Calculates divmod using Newton-Raphson method.
+ * x and y must be a BigDecimal representing an integer value.
+ * The quotient is stored to *div_out and the remainder to *mod_out.
+ *
+ * To calculate with low cost, we need to split x into blocks and perform divmod for each block.
+ * x_digits = remaining_digits(<= y_digits) + block_digits * num_blocks
+ *
+ * Example:
+ * xxx_xxxxx_xxxxx_xxxxx(18 digits) / yyyyy(5 digits)
+ * remaining_digits = 3, block_digits = 5, num_blocks = 3
+ * repeating xxxxx_xxxxx.divmod(yyyyy) calculation 3 times.
+ *
+ * In each divmod step, dividend is at most (y_digits + block_digits) digits and divisor is y_digits digits.
+ * Reciprocal of y needs block_digits + 1 precision.
+ */
+static void
+divmod_newton(VALUE x, VALUE y, VALUE *div_out, VALUE *mod_out) {
+  size_t x_digits = NUM2SIZET(BigDecimal_exponent(x));
+  size_t y_digits = NUM2SIZET(BigDecimal_exponent(y));
+  if (x_digits <= y_digits) x_digits = y_digits + 1;
+
+  size_t n = x_digits / y_digits;
+  size_t block_figs = (x_digits - y_digits) / n / BIGDECIMAL_COMPONENT_FIGURES + 1;
+  size_t block_digits = block_figs * BIGDECIMAL_COMPONENT_FIGURES;
+  size_t num_blocks = (x_digits - y_digits + block_digits - 1) / block_digits;
+  size_t y_figs = (y_digits - 1) / BIGDECIMAL_COMPONENT_FIGURES + 1;
+  VALUE yinv = newton_raphson_inverse(y, block_digits + 1);
+
+  BDVALUE dividend = NewZeroWrap(1, BIGDECIMAL_COMPONENT_FIGURES * (y_figs + block_figs));
+  BDVALUE div_result = NewZeroWrap(1, BIGDECIMAL_COMPONENT_FIGURES * (num_blocks * block_figs + 1));
+  BDVALUE bdx = GetBDValueMust(x);
+
+  VALUE mod = BigDecimal_fix(BigDecimal_decimal_shift(x, SSIZET2NUM(-num_blocks * block_digits)));
+  for (ssize_t i = num_blocks - 1; i >= 0; i--) {
+    memset(dividend.real->frac, 0, (y_figs + block_figs) * sizeof(DECDIG));
+
+    BDVALUE bdmod = GetBDValueMust(mod);
+    slice_copy(dividend.real->frac, bdmod.real, 0, y_figs);
+    slice_copy(dividend.real->frac + y_figs, bdx.real, i * block_figs, block_figs);
+    RB_GC_GUARD(bdmod.bigdecimal);
+
+    VpSetSign(dividend.real, 1);
+    dividend.real->exponent = y_figs + block_figs;
+    dividend.real->Prec = y_figs + block_figs;
+    VpNmlz(dividend.real);
+
+    VALUE div;
+    divmod_by_inv_mul(dividend.bigdecimal, y, yinv, &div, &mod);
+    BDVALUE bddiv = GetBDValueMust(div);
+    slice_copy(div_result.real->frac + (num_blocks - i - 1) * block_figs, bddiv.real, 0, block_figs + 1);
+    RB_GC_GUARD(bddiv.bigdecimal);
+  }
+  VpSetSign(div_result.real, 1);
+  div_result.real->exponent = num_blocks * block_figs + 1;
+  div_result.real->Prec = num_blocks * block_figs + 1;
+  VpNmlz(div_result.real);
+  RB_GC_GUARD(bdx.bigdecimal);
+  RB_GC_GUARD(dividend.bigdecimal);
+  RB_GC_GUARD(div_result.bigdecimal);
+  *div_out = div_result.bigdecimal;
+  *mod_out = mod;
+}
+
+// Body of VpDivdNewton, executed under rb_ensure.
+// args_ptr points to Real *args[4] = {c, r, a, b}; the quotient is assigned to c and the remainder to r.
+static VALUE
+VpDivdNewtonInner(VALUE args_ptr)
+{
+  Real **args = (Real**)args_ptr;
+  Real *c = args[0], *r = args[1], *a = args[2], *b = args[3];
+  BDVALUE a2, b2, c2, r2;
+  VALUE div, mod, a2_frac = Qnil;
+  size_t div_prec = c->MaxPrec - 1;
+  size_t base_prec = b->Prec;
+
+  a2 = NewZeroWrap(1, a->Prec * BIGDECIMAL_COMPONENT_FIGURES);
+  b2 = NewZeroWrap(1, b->Prec * BIGDECIMAL_COMPONENT_FIGURES);
+  VpAsgn(a2.real, a, 1);
+  VpAsgn(b2.real, b, 1);
+  VpSetSign(a2.real, 1);
+  VpSetSign(b2.real, 1);
+  a2.real->exponent = base_prec + div_prec;
+  b2.real->exponent = base_prec;
+
+  if ((ssize_t)a2.real->Prec > a2.real->exponent) {
+    a2_frac = BigDecimal_frac(a2.bigdecimal);
+    VpMidRound(a2.real, VP_ROUND_DOWN, 0);
+  }
+  divmod_newton(a2.bigdecimal, b2.bigdecimal, &div, &mod);
+  if (a2_frac != Qnil) mod = BigDecimal_add(mod, a2_frac);
+
+  c2 = GetBDValueMust(div);
+  r2 = GetBDValueMust(mod);
+  VpAsgn(c, c2.real, VpGetSign(a) * VpGetSign(b));
+  VpAsgn(r, r2.real, VpGetSign(a));
+  AddExponent(c, a->exponent);
+  AddExponent(c, -b->exponent);
+  AddExponent(c, -div_prec);
+  AddExponent(r, a->exponent);
+  AddExponent(r, -base_prec - div_prec);
+  RB_GC_GUARD(a2.bigdecimal);
+  RB_GC_GUARD(b2.bigdecimal);
+  RB_GC_GUARD(c2.bigdecimal);
+  RB_GC_GUARD(r2.bigdecimal);
+  return Qnil;
+}
+
+// rb_ensure callback: restores the precision limit that VpDivdNewton saved before the calculation.
+static VALUE
+ensure_restore_prec_limit(VALUE limit)
+{
+  VpSetPrecLimit(NUM2SIZET(limit));
+  return Qnil;
+}
+
+// Divides a by b with the Newton-Raphson method; the quotient goes to c and the remainder to r.
+// The precision limit is set to 0 while calculating and is restored afterwards, even on exception.
+static void
+VpDivdNewton(Real *c, Real *r, Real *a, Real *b)
+{
+  Real *args[4] = {c, r, a, b};
+  size_t pl = VpGetPrecLimit();
+  VpSetPrecLimit(0);
+  // Ensure restoring prec limit because some methods used in VpDivdNewtonInner may raise an exception
+  rb_ensure(VpDivdNewtonInner, (VALUE)args, ensure_restore_prec_limit, SIZET2NUM(pl));
+}
diff --git a/ext/bigdecimal/ntt.h b/ext/bigdecimal/ntt.h
new file mode 100644
index 00000000..941f23f7
--- /dev/null
+++ b/ext/bigdecimal/ntt.h
@@ -0,0 +1,191 @@
+// NTT (Number Theoretic
Transform) implementation for BigDecimal multiplication
+
+#define NTT_PRIMITIVE_ROOT 17
+#define NTT_PRIME_BASE1 24
+#define NTT_PRIME_BASE2 26
+#define NTT_PRIME_BASE3 29
+#define NTT_PRIME_SHIFT 27
+#define NTT_PRIME1 (((uint32_t)NTT_PRIME_BASE1 << NTT_PRIME_SHIFT) | 1)
+#define NTT_PRIME2 (((uint32_t)NTT_PRIME_BASE2 << NTT_PRIME_SHIFT) | 1)
+#define NTT_PRIME3 (((uint32_t)NTT_PRIME_BASE3 << NTT_PRIME_SHIFT) | 1)
+#define MAX_NTT32_BITS 27
+#define NTT_DECDIG_BASE 1000000000
+
+// Calculates base**ex % mod
+static uint32_t
+mod_pow(uint32_t base, uint32_t ex, uint32_t mod) {
+  uint32_t res = 1;
+  uint32_t bit = 1;
+  while (true) {
+    if (ex & bit) {
+      ex ^= bit;
+      res = ((uint64_t)res * base) % mod;
+    }
+    if (!ex) break;
+    base = ((uint64_t)base * base) % mod;
+    bit <<= 1;
+  }
+  return res;
+}
+
+// Recursively performs butterfly operations of NTT; output and tmp swap roles at each depth, and depth 0 reads directly from input
+static void
+ntt_recursive(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int depth, uint32_t r, uint32_t prime) {
+  if (depth > 0) {
+    ntt_recursive(size_bits, input, tmp, output, depth - 1, ((uint64_t)r * r) % prime, prime);
+  } else {
+    tmp = input;
+  }
+  uint32_t size_half = (uint32_t)1 << (size_bits - 1);
+  uint32_t stride = (uint32_t)1 << (size_bits - depth - 1);
+  uint32_t n = size_half / stride;
+  uint32_t rn = 1, rm = prime - 1;
+  uint32_t idx = 0;
+  for (uint32_t i = 0; i < n; i++) {
+    uint32_t j = i * 2 * stride;
+    for (uint32_t k = 0; k < stride; k++, j++, idx++) {
+      uint32_t a = tmp[j], b = tmp[j + stride];
+      output[idx] = (a + (uint64_t)rn * b) % prime;
+      output[idx + size_half] = (a + (uint64_t)rm * b) % prime;
+    }
+    rn = ((uint64_t)rn * r) % prime;
+    rm = ((uint64_t)rm * r) % prime;
+  }
+}
+
+/* Perform NTT on input array.
+ * base, shift: Represent the prime number as (base << shift | 1)
+ * r_base: Primitive root of unity modulo prime
+ * size_bits: log2 of the size of the input array. Must be less than or equal to shift
+ * input: input array of size (1 << size_bits); the result is written to output, tmp is work space; a negative dir selects the inverse transform
+ */
+static void
+ntt(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int r_base, int base, int shift, int dir) {
+  uint32_t size = (uint32_t)1 << size_bits;
+  uint32_t prime = ((uint32_t)base << shift) | 1;
+
+  // rmax**(1 << shift) % prime == 1
+  // r**size % prime == 1
+  uint32_t rmax = mod_pow(r_base, base, prime);
+  uint32_t r = mod_pow(rmax, (uint32_t)1 << (shift - size_bits), prime);
+
+  if (dir < 0) r = mod_pow(r, prime - 2, prime);
+  ntt_recursive(size_bits, input, output, tmp, size_bits - 1, r, prime);
+  if (dir < 0) {
+    uint32_t n_inv = mod_pow((uint32_t)size, prime - 2, prime);
+    for (uint32_t i = 0; i < size; i++) {
+      output[i] = ((uint64_t)output[i] * n_inv) % prime;
+    }
+  }
+}
+
+/* Calculate c that satisfies: c % PRIME1 == mod1 && c % PRIME2 == mod2 && c % PRIME3 == mod3
+ * c = (mod1 * 35002755423056150739595925972 + mod2 * 14584479687667766215746868453 + mod3 * 37919651490985126265126719818) % (PRIME1 * PRIME2 * PRIME3)
+ * Assume c <= 999999999**2*(1<<27)
+ */
+static inline void
+mod_restore_prime_24_26_29_shift_27(uint32_t mod1, uint32_t mod2, uint32_t mod3, uint32_t *digits) {
+  // Use mixed radix notation to eliminate modulo by PRIME1 * PRIME2 * PRIME3
+  // [DIG0, DIG1, DIG2] = DIG0 + DIG1 * PRIME1 + DIG2 * PRIME1 * PRIME2
+  // DIG0: 0...PRIME1, DIG1: 0...PRIME2, DIG2: 0...PRIME3
+  // 35002755423056150739595925972 = [1, 3489660916, 3113851359]
+  // 14584479687667766215746868453 = [0, 13, 1297437912]
+  // 37919651490985126265126719818 = [0, 0, 3373338954]
+  uint64_t c0 = mod1;
+  uint64_t c1 = (uint64_t)mod2 * 13 + (uint64_t)mod1 * 3489660916;
+  uint64_t c2 = (uint64_t)mod3 * 3373338954 % NTT_PRIME3 + (uint64_t)mod2 * 1297437912 % NTT_PRIME3 + (uint64_t)mod1 * 3113851359 % NTT_PRIME3;
+  c2 += c1 / NTT_PRIME2;
+  c1 %= NTT_PRIME2;
+  c2 %= NTT_PRIME3;
+  // Base conversion. c fits in 3 digits.
+  c1 += c2 % NTT_DECDIG_BASE * NTT_PRIME2;
+  c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1;
+  c1 /= NTT_DECDIG_BASE;
+  digits[0] = c0 % NTT_DECDIG_BASE;
+  c0 /= NTT_DECDIG_BASE;
+  c1 += c2 / NTT_DECDIG_BASE % NTT_DECDIG_BASE * NTT_PRIME2;
+  c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1;
+  c1 /= NTT_DECDIG_BASE;
+  digits[1] = c0 % NTT_DECDIG_BASE;
+  digits[2] = (uint32_t)(c0 / NTT_DECDIG_BASE + c1 % NTT_DECDIG_BASE * NTT_PRIME1);
+}
+
+/*
+ * NTT multiplication: multiplies a (a_size DECDIGs) by b (b_size DECDIGs) and stores a_size + b_size DECDIGs to c
+ * Uses three NTTs with mod (24 << 27 | 1), (26 << 27 | 1), and (29 << 27 | 1)
+ */
+static void
+ntt_multiply(size_t a_size, size_t b_size, uint32_t *a, uint32_t *b, uint32_t *c) {
+  if (a_size < b_size) {
+    ntt_multiply(b_size, a_size, b, a, c);
+    return;
+  }
+
+  int b_bits = 0;
+  while (((uint32_t)1 << b_bits) < (uint32_t)b_size) b_bits++;
+  int ntt_size_bits = b_bits + 1;
+  if (ntt_size_bits > MAX_NTT32_BITS) {
+    rb_raise(rb_eArgError, "Multiply size too large");
+  }
+
+  // To calculate large_a * small_b faster, split into several batches.
+  uint32_t ntt_size = (uint32_t)1 << ntt_size_bits;
+  uint32_t batch_size = ntt_size - (uint32_t)b_size;
+  uint32_t batch_count = (uint32_t)((a_size + batch_size - 1) / batch_size);
+
+  uint32_t *mem = ruby_xcalloc((size_t)ntt_size * 9, sizeof(uint32_t));
+  uint32_t *ntt1 = mem;
+  uint32_t *ntt2 = mem + ntt_size;
+  uint32_t *ntt3 = mem + ntt_size * 2;
+  uint32_t *tmp1 = mem + ntt_size * 3;
+  uint32_t *tmp2 = mem + ntt_size * 4;
+  uint32_t *tmp3 = mem + ntt_size * 5;
+  uint32_t *conv1 = mem + ntt_size * 6;
+  uint32_t *conv2 = mem + ntt_size * 7;
+  uint32_t *conv3 = mem + ntt_size * 8;
+
+  // Calculate NTT for b in three primes. Result is reused for each batch of a.
+  memcpy(tmp1, b, b_size * sizeof(uint32_t));
+  memset(tmp1 + b_size, 0, (ntt_size - b_size) * sizeof(uint32_t));
+  ntt(ntt_size_bits, tmp1, ntt1, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1);
+  ntt(ntt_size_bits, tmp1, ntt2, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1);
+  ntt(ntt_size_bits, tmp1, ntt3, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1);
+
+  memset(c, 0, (a_size + b_size) * sizeof(uint32_t));
+  for (uint32_t idx = 0; idx < batch_count; idx++) {
+    uint32_t len = idx == batch_count - 1 ? (uint32_t)a_size - idx * batch_size : batch_size;
+    memcpy(tmp1, a + idx * batch_size, len * sizeof(uint32_t));
+    memset(tmp1 + len, 0, (ntt_size - len) * sizeof(uint32_t));
+    // Calculate convolution for this batch in three primes
+    ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1);
+    for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt1[i]) % NTT_PRIME1;
+    ntt(ntt_size_bits, tmp2, conv1, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, -1);
+    ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1);
+    for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt2[i]) % NTT_PRIME2;
+    ntt(ntt_size_bits, tmp2, conv2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, -1);
+    ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1);
+    for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt3[i]) % NTT_PRIME3;
+    ntt(ntt_size_bits, tmp2, conv3, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, -1);
+
+    // Restore the original convolution value from three convolutions calculated in three primes.
+    // Each convolution value is maximum 999999999**2*(1<<27)/2
+    for (uint32_t i = 0; i < ntt_size; i++) {
+      uint32_t dig[3];
+      mod_restore_prime_24_26_29_shift_27(conv1[i], conv2[i], conv3[i], dig);
+      // Maximum values of dig[0], dig[1], and dig[2] are 999999999, 999999999 and 67108863 respectively
+      // Maximum overlapped sum (considering overlaps between 2 batches) is less than 4134217722
+      // so this sum doesn't overflow uint32_t.
+      for (int j = 0; j < 3; j++) {
+        // Index check: if dig[j] is non-zero, assign index is within valid range.
+        if (dig[j]) c[idx * batch_size + i + 1 - j] += dig[j];
+      }
+    }
+  }
+  uint32_t carry = 0;
+  for (int32_t i = (uint32_t)(a_size + b_size - 1); i >= 0; i--) {
+    uint32_t v = c[i] + carry;
+    c[i] = v % NTT_DECDIG_BASE;
+    carry = v / NTT_DECDIG_BASE;
+  }
+  ruby_xfree(mem);
+}
diff --git a/lib/bigdecimal.rb b/lib/bigdecimal.rb
index 12250ce9..998087d8 100644
--- a/lib/bigdecimal.rb
+++ b/lib/bigdecimal.rb
@@ -60,6 +60,46 @@ def self.nan_computation_result # :nodoc:
       end
       BigDecimal::NAN
     end
+
+    # Iteration for Newton's method with increasing precision
+    # Yields each working precision from the smallest up to +prec+ (roughly doubling).
+    def self.newton_loop(prec, initial_precision: BigDecimal.double_fig / 2, safe_margin: 2) # :nodoc:
+      precs = []
+      while prec > initial_precision
+        precs << prec
+        prec = (prec + 1) / 2 + safe_margin
+        break if prec >= precs.last # halving no longer shrinks prec; stop instead of looping forever
+      end
+      precs.reverse_each { |p| yield p }
+    end
+
+    # Calculates Math.log(x.to_f) considering large or small exponent
+    def self.float_log(x) # :nodoc:
+      Math.log(x._decimal_shift(-x.exponent).to_f) + x.exponent * Math.log(10)
+    end
+
+    # Calculating Taylor series sum using binary splitting method
+    # Calculates f(x) = (x/d0)*(1+(x/d1)*(1+(x/d2)*(1+(x/d3)*(1+...))))
+    # x.n_significant_digits or ds.size must be small to be performant.
+    def self.taylor_sum_binary_splitting(x, ds, prec) # :nodoc:
+      fs = ds.map {|d| [0, BigDecimal(d)] }
+      # fs = [[a0, a1], [b0, b1], [c0, c1], ...]
+      # f(x) = a0/a1+(x/a1)*(1+b0/b1+(x/b1)*(1+c0/c1+(x/c1)*(1+d0/d1+(x/d1)*(1+...))))
+      while fs.size > 1
+        # Merge two adjacent fractions
+        # from: (1 + a0/a1 + x/a1 * (1 + b0/b1 + x/b1 * rest))
+        # to: (1 + (a0*b1+x*(b0+b1))/(a1*b1) + (x*x)/(a1*b1) * rest)
+        xn = xn ? xn.mult(xn, prec) : x
+        fs = fs.each_slice(2).map do |(a, b)|
+          b ||= [0, BigDecimal(1)._decimal_shift([xn.exponent, 0].max + 2)]
+          [
+            (a[0] * b[1]).add(xn * (b[0] + b[1]), prec),
+            a[1].mult(b[1], prec)
+          ]
+        end
+      end
+      BigDecimal(fs[0][0]).div(fs[0][1], prec)
+    end
   end
 
   # call-seq:
@@ -226,9 +266,7 @@ def sqrt(prec)
     ex = exponent / 2
     x = _decimal_shift(-2 * ex)
     y = BigDecimal(Math.sqrt(x.to_f), 0)
-    precs = [prec + BigDecimal.double_fig]
-    precs << 2 + precs.last / 2 while precs.last > BigDecimal.double_fig
-    precs.reverse_each do |p|
+    Internal.newton_loop(prec + BigDecimal.double_fig) do |p|
       y = y.add(x.div(y, p), p).div(2, p)
     end
     y._decimal_shift(ex).mult(1, prec)
@@ -264,59 +302,32 @@ def log(x, prec)
     return BigDecimal(0) if x == 1
 
     prec2 = prec + BigDecimal.double_fig
-    BigDecimal.save_limit do
-      BigDecimal.limit(0)
-      if x > 10 || x < 0.1
-        log10 = log(BigDecimal(10), prec2)
-        exponent = x.exponent
-        x = x._decimal_shift(-exponent)
-        if x < 0.3
-          x *= 10
-          exponent -= 1
-        end
-        return (log10 * exponent).add(log(x, prec2), prec)
-      end
-
-      x_minus_one_exponent = (x - 1).exponent
-      # log(x) = log(sqrt(sqrt(sqrt(sqrt(x))))) * 2**sqrt_steps
-      sqrt_steps = [Integer.sqrt(prec2) + 3 * x_minus_one_exponent, 0].max
-
-      lg2 = 0.3010299956639812
-      sqrt_prec = prec2 + [-x_minus_one_exponent, 0].max + (sqrt_steps * lg2).ceil
-
-      sqrt_steps.times do
-        x = x.sqrt(sqrt_prec)
-      end
-
-      # Taylor series for log(x) around 1
-      # log(x) = -log((1 + X) / (1 - X)) where X = (x - 1) / (x + 1)
-      # log(x) = 2 * (X + X**3 / 3 + X**5 / 5 + X**7 / 7 + ...)
-      x = (x - 1).div(x + 1, sqrt_prec)
-      y = x
-      x2 = x.mult(x, prec2)
-      1.step do |i|
-        n = prec2 + x.exponent - y.exponent + x2.exponent
-        break if n <= 0 || x.zero?
-        x = x.mult(x2.round(n - x2.exponent), n)
-        y = y.add(x.div(2 * i + 1, n), prec2)
-      end
+    if x < 0.1 || x > 10
+      exponent = (3 * x).exponent - 1
+      x = x._decimal_shift(-exponent)
+      return log(10, prec2).mult(exponent, prec2).add(log(x, prec2), prec)
+    end
 
-      y.mult(2 ** (sqrt_steps + 1), prec)
+    # Solve exp(y) - x = 0 with Newton's method
+    # Repeat: y -= (exp(y) - x) / exp(y)
+    y = BigDecimal(BigDecimal::Internal.float_log(x), 0)
+    exp_additional_prec = [-(x - 1).exponent, 0].max
+    BigDecimal::Internal.newton_loop(prec2) do |p|
+      expy = exp(y, p + exp_additional_prec)
+      y = y.sub(expy.sub(x, p).div(expy, p), p)
     end
+    y.mult(1, prec)
   end
 
-  # Taylor series for exp(x) around 0
-  private_class_method def _exp_taylor(x, prec) # :nodoc:
-    xn = BigDecimal(1)
-    y = BigDecimal(1)
-    1.step do |i|
-      n = prec + xn.exponent
-      break if n <= 0 || xn.zero?
-      xn = xn.mult(x, n).div(i, n)
-      y = y.add(xn, prec)
-    end
-    y
+  private_class_method def _exp_binary_splitting(x, prec) # :nodoc:
+    return BigDecimal(1) if x.zero?
+    # Find the smallest k that satisfies |x|**k / k! < 10**(-prec), compared in log scale via lgamma
+    log10 = Math.log(10)
+    logx = BigDecimal::Internal.float_log(x.abs)
+    step = (1..).bsearch { |k| Math.lgamma(k + 1)[0] - k * logx > prec * log10 }
+    # exp(x)-1 = x*(1+x/2*(1+x/3*(1+x/4*(1+x/5*(1+...)))))
+    1 + BigDecimal::Internal.taylor_sum_binary_splitting(x, [*1..step], prec)
   end
 
   # call-seq:
@@ -341,11 +352,21 @@ def exp(x, prec)
     prec2 = prec + BigDecimal.double_fig + cnt
     x = x._decimal_shift(-cnt)
 
-    # Calculation of exp(small_prec) is fast because calculation of x**n is fast
-    # Calculation of exp(small_abs) converges fast.
-    # exp(x) = exp(small_prec_part + small_abs_part) = exp(small_prec_part) * exp(small_abs_part)
-    x_small_prec = x.round(Integer.sqrt(prec2))
-    y = _exp_taylor(x_small_prec, prec2).mult(_exp_taylor(x.sub(x_small_prec, prec2), prec2), prec2)
+    # Decimal form of bit-burst algorithm
+    # Calculate exp(x.xxxxxxxxxxxxxxxx) as
+    # exp(x.xx) * exp(0.00xx) * exp(0.0000xxxx) * exp(0.00000000xxxxxxxx)
+    x = x.mult(1, prec2)
+    n = 2
+    y = BigDecimal(1)
+    BigDecimal.save_limit do
+      BigDecimal.limit(0)
+      while x != 0 do
+        partial_x = x.truncate(n)
+        x -= partial_x
+        y = y.mult(_exp_binary_splitting(partial_x, prec2), prec2)
+        n *= 2
+      end
+    end
 
     # calculate exp(x * 10**cnt) from exp(x)
     # exp(x * 10**k) = exp(x * 10**(k - 1)) ** 10
diff --git a/lib/bigdecimal/math.rb b/lib/bigdecimal/math.rb
index d0d49cb8..89bf9e12 100644
--- a/lib/bigdecimal/math.rb
+++ b/lib/bigdecimal/math.rb
@@ -88,6 +88,37 @@ def sqrt(x, prec)
     end
   end
 
+  private_class_method def _sin_binary_splitting(x, prec) # :nodoc:
+    return x if x.zero?
+    x2 = x.mult(x, prec)
+    # Find the smallest k that satisfies x2**k / (2k)! < 10**(-prec); this is slightly stricter than bounding the series term x2**k / (2k+1)!
+    log10 = Math.log(10)
+    logx = BigDecimal::Internal.float_log(x.abs)
+    step = (1..).bsearch { |k| Math.lgamma(2 * k + 1)[0] - 2 * k * logx > prec * log10 }
+    # Construct denominator sequence for binary splitting
+    # sin(x) = x*(1-x2/(2*3)*(1-x2/(4*5)*(1-x2/(6*7)*(1-x2/(8*9)*(1-...)))))
+    ds = (1..step).map {|i| -(2 * i) * (2 * i + 1) }
+    x.mult(1 + BigDecimal::Internal.taylor_sum_binary_splitting(x2, ds, prec), prec)
+  end
+
+  private_class_method def _sin_around_zero(x, prec) # :nodoc:
+    # Divide x into several parts
+    # sin(x.xxxxxxxx...) = sin(x.xx + 0.00xx + 0.0000xxxx + ...)
+    # Calculate sin of each part and restore sin(0.xxxxxxxx...) using addition theorem.
+ sin = BigDecimal(0) + cos = BigDecimal(1) + n = 2 + while x != 0 do + partial_x = x.truncate(n) + x -= partial_x + s = _sin_binary_splitting(partial_x, prec) + c = (1 - s * s).sqrt(prec) + sin, cos = (sin * c).add(cos * s, prec), (cos * c).sub(sin * s, prec) + n *= 2 + end + sin.clamp(BigDecimal(-1), BigDecimal(1)) + end + # call-seq: # cbrt(decimal, numeric) -> BigDecimal # @@ -150,26 +181,9 @@ def sin(x, prec) prec = BigDecimal::Internal.coerce_validate_prec(prec, :sin) x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :sin) return BigDecimal::Internal.nan_computation_result if x.infinite? || x.nan? - n = prec + BigDecimal.double_fig - one = BigDecimal("1") - two = BigDecimal("2") + n = prec + BigDecimal.double_fig sign, x = _sin_periodic_reduction(x, n) - x1 = x - x2 = x.mult(x,n) - y = x - d = y - i = one - z = one - while d.nonzero? && ((m = n - (y.exponent - d.exponent).abs) > 0) - m = BigDecimal.double_fig if m < BigDecimal.double_fig - x1 = -x2.mult(x1,n) - i += two - z *= (i-one) * i - d = x1.div(z,m) - y += d - end - y = BigDecimal("1") if y > 1 - y.mult(sign, prec) + _sin_around_zero(x, n).mult(sign, prec) end # call-seq: @@ -187,8 +201,9 @@ def cos(x, prec) prec = BigDecimal::Internal.coerce_validate_prec(prec, :cos) x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :cos) return BigDecimal::Internal.nan_computation_result if x.infinite? || x.nan? - sign, x = _sin_periodic_reduction(x, prec + BigDecimal.double_fig, add_half_pi: true) - sign * sin(x, prec) + n = prec + BigDecimal.double_fig + sign, x = _sin_periodic_reduction(x, n, add_half_pi: true) + _sin_around_zero(x, n).mult(sign, prec) end # call-seq: @@ -277,28 +292,21 @@ def atan(x, prec) x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :atan) return BigDecimal::Internal.nan_computation_result if x.nan? n = prec + BigDecimal.double_fig - pi = PI(n) + return PI(n).div(2 * x.infinite?, prec) if x.infinite? + x = -x if neg = x < 0 - return pi.div(neg ? -2 : 2, prec) if x.infinite? 
- return pi.div(neg ? -4 : 4, prec) if x.round(prec) == 1 - x = BigDecimal("1").div(x, n) if inv = x > 1 - x = (-1 + sqrt(1 + x.mult(x, n), n)).div(x, n) if dbl = x > 0.5 - y = x - d = y - t = x - r = BigDecimal("3") - x2 = x.mult(x,n) - while d.nonzero? && ((m = n - (y.exponent - d.exponent).abs) > 0) - m = BigDecimal.double_fig if m < BigDecimal.double_fig - t = -t.mult(x2,n) - d = t.div(r,m) - y += d - r += 2 + x = BigDecimal(1).div(x, n) if inv = x < -1 || x > 1 + + # Solve tan(y) - x = 0 with Newton's method + # Repeat: y -= (tan(y) - x) * cos(y)**2 + y = BigDecimal(Math.atan(x.to_f), 0) + BigDecimal::Internal.newton_loop(n) do |p| + s = sin(y, p) + c = (1 - s * s).sqrt(p) + y = y.sub(c * (s.sub(c * x.mult(1, p), p)), p) end - y *= 2 if dbl - y = pi / 2 - y if inv - y = -y if neg - y.mult(1, prec) + y = PI(n) / 2 - y if inv + y.mult(neg ? -1 : 1, prec) end # call-seq: @@ -568,6 +576,300 @@ def expm1(x, prec) exp_prec > 0 ? exp(x, exp_prec).sub(1, prec) : BigDecimal(-1) end + # call-seq: + # erf(decimal, numeric) -> BigDecimal + # + # Computes the error function of +decimal+ to the specified number of digits of + # precision, +numeric+. + # + # If +decimal+ is NaN, returns NaN. + # + # BigMath.erf(BigDecimal('1'), 32).to_s + # #=> "0.84270079294971486934122063508261e0" + # + def erf(x, prec) + prec = BigDecimal::Internal.coerce_validate_prec(prec, :erf) + x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :erf) + return BigDecimal::Internal.nan_computation_result if x.nan? + return BigDecimal(x.infinite?) if x.infinite? 
+ return BigDecimal(0) if x == 0 + return -erf(-x, prec) if x < 0 + return BigDecimal(1) if x > 5000000000 # erf(5000000000) > 1 - 1e-10000000000000000000 + + if x > 8 + xf = x.to_f + log10_erfc = -xf ** 2 / Math.log(10) - Math.log10(xf * Math::PI ** 0.5) + erfc_prec = [prec + log10_erfc.ceil, 1].max + erfc = _erfc_bit_burst(x, erfc_prec + BigDecimal.double_fig) + return BigDecimal(1).sub(erfc, prec) if erfc + end + + _erf_bit_burst(x, prec + BigDecimal.double_fig).mult(1, prec) + end + + def erfc(x, prec) + prec = BigDecimal::Internal.coerce_validate_prec(prec, :erfc) + x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :erfc) + return BigDecimal::Internal.nan_computation_result if x.nan? + return BigDecimal(1 - x.infinite?) if x.infinite? + return BigDecimal(1).sub(erf(x, prec), prec) if x < 0 + return BigDecimal(0) if x > 5000000000 # erfc(5000000000) < 1e-10000000000000000000 (underflow) + + if x >= 8 + y = _erfc_bit_burst(x, prec + BigDecimal.double_fig) + return y.mult(1, prec) if y + end + + # erfc(x) = 1 - erf(x) < exp(-x**2)/x/sqrt(pi) + # Precision of erf(x) needs about log10(exp(x**2)) = x**2/log(10) extra digits + log10 = 2.302585092994046 + high_prec = prec + BigDecimal.double_fig + (x.to_f**2 / log10).ceil + BigDecimal(1).sub(_erf_bit_burst(x, high_prec), prec) + end + + # Calculates erf(x) using bit-burst algorithm. + private_class_method def _erf_bit_burst(x, prec) + x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :erf) + prec = BigDecimal::Internal.coerce_validate_prec(prec, :erf) + + return BigDecimal(0) if x > 5000000000 # erfc underflows. NOTE(review): returns 0 although erf(x) is ~1 here; erf/erfc return before calling this for such x - confirm intent + x = x.mult(1, [prec - (x.to_f**2/Math.log(10)).floor, 1].max) + + calculated_x = BigDecimal(0) + erf_exp2 = BigDecimal(0) + digits = 8 + xf = x.to_f + scale = 2 * exp(-x.mult(x, prec), prec).div(PI(prec).sqrt(prec), prec) + + until x.zero? + partial = x.truncate(digits) + digits *= 2 + next if partial.zero?
+ + erf_exp2 = _erf_exp2_binary_splitting(partial, calculated_x, erf_exp2, prec) + calculated_x += partial + x -= partial + end + erf_exp2.mult(scale, prec) + end + + # Calculates erfc(x) using bit-burst algorithm. + private_class_method def _erfc_bit_burst(x, prec) # :nodoc: + digits = (x.exponent + 1) * 40 + + calculated_x = x.truncate(digits) + f = _erfc_exp2_asymptotic_binary_splitting(calculated_x, prec) + return unless f + + scale = 2 * exp(-x.mult(x, prec), prec).div(PI(prec).sqrt(prec), prec) + x -= calculated_x + + until x.zero? + digits *= 2 + partial = x.truncate(digits) + next if partial.zero? + + f = _erfc_exp2_inv_inv_binary_splitting(partial, calculated_x, f, prec) + calculated_x += partial + x -= partial + end + f.mult(scale, prec) + end + + # Matrix multiplication for binary splitting method in erf/erfc calculation + private_class_method def _bs_matrix_mult(m1, m2, size, prec) # :nodoc: + (size * size).times.map do |i| + size.times.map do |k| + m1[i / size * size + k].mult(m2[size * k + i % size], prec) + end.reduce {|a, b| a.add(b, prec) } + end + end + + # Matrix/Vector weighted sum for binary splitting method in erf/erfc calculation + private_class_method def _bs_weighted_sum(m1, w1, m2, w2, prec) # :nodoc: + m1.zip(m2).map {|v1, v2| (v1 * w1).add(v2 * w2, prec) } + end + + # Calculates Taylor expansion of erf(x+a)*exp((x+a)**2)*sqrt(pi)/2 with binary splitting method. 
+ private_class_method def _erf_exp2_binary_splitting(x, a, f_a, prec) # :nodoc: + cexponent = Math.log10([2 * a, Math.sqrt(2)].max.to_f) + log10(x.abs, 10) + log10f = Math.log(10) + + steps = BigDecimal.save_exception_mode do + BigDecimal.mode(BigDecimal::EXCEPTION_UNDERFLOW, false) + (2..).bsearch do |n| + x.to_f ** 2 < n && n * cexponent + Math.lgamma(n / 2)[0] / log10f + n * Math.log10(2) - Math.lgamma(n - 1)[0] / log10f < -prec + x.to_f**2 / log10f + end + end + + if a == 0 + denominators = (steps / 2).times.map {|i| 2 * i + 3 } + return x.mult(1 + BigDecimal::Internal.taylor_sum_binary_splitting(2 * x * x, denominators, prec), prec) + end + + # First, calculate a matrix that represents the sum of the Taylor series: + # SumMatrix = (((((...+I)x*M4+I)*x*M3+I)*M2*x+I)*M1*x+I) + # Where Mi is a 2x2 matrix that generates the next coefficients of Taylor series: + # Vector(c4, c5) = M4*M3*M2*M1*Vector(c0, c1) + # And then calculates: + # SumMatrix * Vector(c0, c1) = Vector(c0+c1*x+c2*x**2+..., _) + # In this binary splitting method, adjacent two operations are combined into one repeatedly. + # ((...) * x * A + B) / C is the form of each operation. A and B are 2x2 matrices, C is a scalar. + zero = BigDecimal(0) + two = BigDecimal(2) + two_a = two * a + operations = steps.times.map do |i| + n = BigDecimal(2 + i) + [[zero, n, two, two_a], [n, zero, zero, n], n] + end + + while operations.size > 1 + xpow = xpow ? 
xpow.mult(xpow, prec) : x.mult(1, prec) + operations = operations.each_slice(2).map do |op1, op2| + # Combine two operations into one: + # (((Remaining * x * A2 + B2) / C2) * x * A1 + B1) / C1 + # = (Remaining * (x*x) * (A2*A1) + (x*B2*A1+B1*C2)) / (C1*C2) + # Therefore, combined operation can be represented as: + # Anext = A2 * A1 + # Bnext = x * B2 * A1 + B1 * C2 + # Cnext = C1 * C2 + # xnext = x * x + a1, b1, c1 = op1 + a2, b2, c2 = op2 || [[zero] * 4, [zero] * 4, BigDecimal(1)] + [ + _bs_matrix_mult(a2, a1, 2, prec), + _bs_weighted_sum(_bs_matrix_mult(b2, a1, 2, prec), xpow, b1, c2, prec), + c1.mult(c2, prec), + ] + end + end + _, sum_matrix, denominator = operations.first + (sum_matrix[1] + f_a * (2 * a * sum_matrix[1] + sum_matrix[0])).div(denominator, prec) + end + + # Calculates asymptotic expansion of erfc(x)*exp(x**2)*sqrt(pi)/2 with binary splitting method + private_class_method def _erfc_exp2_asymptotic_binary_splitting(x, prec) # :nodoc: + # Let f(x) = erfc(x)*sqrt(pi)*exp(x**2)/2 + # f(x) satisfies the following differential equation: + # 2*x*f(x) = f'(x) + 1 + # From the above equation, we can derive the following asymptotic expansion: + # f(x) = (0..kmax).sum { (-1)**k * (2*k)! / 4**k / k! / x**(2*k) } / (2*x) + + # This asymptotic expansion does not converge. + # But if there is a k that satisfies (2*k)! / 4**k / k! / x**(2*k) < 10**(-prec), + # It is enough to calculate erfc within the given precision. + # Using Stirling's approximation, we can simplify this condition to: + # log(2)/2 + k*log(k) - k - 2*k*log(x) < -prec*log(10) + # and the left side is minimized when k = x**2.
+ xf = x.to_f + kmax = (1..(xf ** 2).floor).bsearch do |k| + Math.log(2) / 2 + k * Math.log(k) - k - 2 * k * Math.log(xf) < -prec * Math.log(10) + end + return unless kmax + + # Convert asymptotic expansion to nested form: + # 1 + a/x + a*b/x/x + a*b*c/x/x/x + a*b*c/x/x/x*rest + # = 1 + (a/x) * (1 + (b/x) * (1 + (c/x) * (1 + rest))) + # + # And calculate it with binary splitting: + # (a1/d + b1/d * (a2/d + b2/d * (rest))) + # = ((a1*d+b1*a2)/(d*d) + b1*b2/(d*d) * (rest)) + denominator = x.mult(x, prec).mult(2, prec) + fractions = (1..kmax).map do |k| + [denominator, BigDecimal(1 - 2 * k)] + end + while fractions.size > 1 + fractions = fractions.each_slice(2).map do |fraction1, fraction2| + a1, b1 = fraction1 + a2, b2 = fraction2 || [BigDecimal(0), denominator] + [ + a1.mult(denominator, prec).add(b1.mult(a2, prec), prec), + b1.mult(b2, prec), + ] + end + denominator = denominator.mult(denominator, prec) + end + sum = fractions[0][0].add(fractions[0][1], prec).div(denominator, prec) + sum.div(x, prec) / 2 + end + + # Calculates f(1/(a + x)) where f(x) = (sqrt(pi)/2) * exp(1/x**2) * erfc(1/x) + # f(1/(a+x)) = f(1/a - x/(a*(a+x))) + private_class_method def _erfc_exp2_inv_inv_binary_splitting(x, a, f_inva, prec) + return f_inva if x.zero? + # f(x) satisfies the following differential equation: + # (1/a+w)**3*f'(1/a+w) + 2*f(1/a+w) = 1/a + w + # From the above equation, we can derive the following Taylor expansion of f around 1/a (in powers of w): + # Coefficients: f(1/a + w) = c0 + c1*w + c2*w**2 + c3*w**3 + ... + # Constraints: + # (w**3 + 3*w**2/a + 3*w/a**2 + 1/a**3) * (c1 + 2*c2*w + 3*c3*w**2 + 4*c4*w**3 + ...) + # + 2 * (c0 + c1*w + c2*w**2 + c3*w**3 + ...)
= 1/a + w + # Recurrence relations: + # c0 = f(1/a) + # c1 = a**2 - 2*c0*a**3 + # c2 = (a**3 - 3*c1*a - 2*c1*a**3) / 2 + # c3 = -(3*c1*a**2 + 6*c2*a + 2*c2*a**3) / 3 + # c(n) = -((n-3)*c(n-3)*a**3 + 3*(n-2)*c(n-2)*a**2 + 3*(n-1)*c(n-1)*a + 2*c(n-1)*a**3) / n + + aa = a.mult(a, prec) + aaa = aa.mult(a, prec) + c0 = f_inva + c1 = (aa - 2 * c0 * aaa).mult(1, prec) + c2 = (aaa - 3 * c1 * a - 2 * c1 * aaa).div(2, prec) + + # Estimate the number of steps needed to achieve the required precision + low_prec = 16 + w = x.div(a.mult(a + x, low_prec), low_prec) + wpow = w.mult(w, low_prec) + cm3, cm2, cm1 = [c0, c1, c2].map {|v| v.mult(1, low_prec) } + a_low, aa_low, aaa_low = [a, aa, aaa].map {|v| v.mult(1, low_prec) } + step = (3..).find do |n| + wpow = wpow.mult(w, low_prec) + cn = -((n - 3) * cm3 * aaa_low + 3 * aa_low * (n - 2) * cm2 + 3 * a_low * (n - 1) * cm1 + 2 * cm1 * aaa_low).div(n, low_prec) + cm3, cm2, cm1 = cm2, cm1, cn + cn.mult(wpow, low_prec).exponent < -prec + end + + # Let M(n) be a 3x3 matrix that transforms (c(n),c(n+1),c(n+2)) to (c(n-1),c(n),c(n+1)) + # Mn = | 0 1 0 | + # | 0 0 1 | + # | -(n-3)*aaa/n -3*(n-2)*aa/n -2*aaa-3*(n-1)*a/n | + # Vector(c6,c7,c8) = M6*M5*M4*M3*M2*M1 * Vector(c0,c1,c2) + # Vector(c0+c1*y/z+c2*(y/z)**2+..., _, _) = (((... + I)*M3*y/z + I)*M2*y/z + I)*M1*y/z + I) * Vector(c2, c1, c0) + # Perform binary splitting on this nested parenthesized calculation by using the following formula: + # (((...)*A2*y/z + B2)/D2 * A1*y/z + B1)/D1 = (((...)*(A2*A1)*(y*y)/z + (B2*A1*y+z*D2*B1)) / (D1*D2*z) + # where A_n, Bn are matrices and Dn are scalars + + zero = BigDecimal(0) + operations = (3..step + 2).map do |n| + bign = BigDecimal(n) + [ + [ + zero, bign, zero, + zero, zero, bign, + BigDecimal(-(n - 3) * aaa), -3 * (n - 2) * aa, -2 * aaa - 3 * (n - 1) * a + ], + [bign, zero, zero, zero, bign, zero, zero, zero, bign], + bign + ] + end + + z = a.mult(a + x, prec) + while operations.size > 1 + y = y ? 
y.mult(y, prec) : -x.mult(1, prec) + operations = operations.each_slice(2).map do |op1, op2| + a1, b1, d1 = op1 + a2, b2, d2 = op2 || [[zero] * 9, [zero] * 9, BigDecimal(1)] + [ + _bs_matrix_mult(a2, a1, 3, prec), + _bs_weighted_sum(_bs_matrix_mult(b2, a1, 3, prec), y, b1, d2.mult(z, prec), prec), + d1.mult(d2, prec).mult(z, prec), + ] + end + end + _, sum_matrix, denominator = operations[0] + (sum_matrix[0] * c0 + sum_matrix[1] * c1 + sum_matrix[2] * c2).div(denominator, prec) + end # call-seq: # PI(numeric) -> BigDecimal @@ -580,38 +882,18 @@ def expm1(x, prec) # def PI(prec) prec = BigDecimal::Internal.coerce_validate_prec(prec, :PI) - n = prec + BigDecimal.double_fig - zero = BigDecimal("0") - one = BigDecimal("1") - two = BigDecimal("2") - - m25 = BigDecimal("-0.04") - m57121 = BigDecimal("-57121") - - pi = zero - - d = one - k = one - t = BigDecimal("-80") - while d.nonzero? && ((m = n - (pi.exponent - d.exponent).abs) > 0) - m = BigDecimal.double_fig if m < BigDecimal.double_fig - t = t*m25 - d = t.div(k,m) - k = k+two - pi = pi + d - end - - d = one - k = one - t = BigDecimal("956") - while d.nonzero? 
&& ((m = n - (pi.exponent - d.exponent).abs) > 0) - m = BigDecimal.double_fig if m < BigDecimal.double_fig - t = t.div(m57121,n) - d = t.div(k,m) - pi = pi + d - k = k+two + n = prec + BigDecimal.double_fig + a = BigDecimal(1) + b = BigDecimal(0.5, 0).sqrt(n) + s = BigDecimal(0.25, 0) + t = 1 + while a != b && (a - b).exponent > 1 - n + c = (a - b).div(2, n) + a, b = (a + b).div(2, n), (a * b).sqrt(n) + s = s.sub(c * c * t, n) + t *= 2 end - pi.mult(1, prec) + (a * b).div(s, prec) end # call-seq: diff --git a/test/bigdecimal/test_bigmath.rb b/test/bigdecimal/test_bigmath.rb index 39dee611..9cb27e85 100644 --- a/test/bigdecimal/test_bigmath.rb +++ b/test/bigdecimal/test_bigmath.rb @@ -182,8 +182,13 @@ def test_sin assert_converge_in_precision {|n| sin(BigDecimal("1e-30"), n) } assert_converge_in_precision {|n| sin(BigDecimal(PI(50)), n) } assert_converge_in_precision {|n| sin(BigDecimal(PI(50) * 100), n) } - assert_operator(sin(PI(30) / 2, 30), :<=, 1) - assert_operator(sin(-PI(30) / 2, 30), :>=, -1) + [:up, :down].each do |mode| + BigDecimal.save_rounding_mode do + BigDecimal.mode(BigDecimal::ROUND_MODE, mode) + assert_operator(sin(PI(30) / 2, 30), :<=, 1) + assert_operator(sin(-PI(30) / 2, 30), :>=, -1) + end + end end def test_cos @@ -205,8 +210,13 @@ def test_cos assert_converge_in_precision {|n| cos(BigDecimal("1e50"), n) } assert_converge_in_precision {|n| cos(BigDecimal(PI(50) / 2), n) } assert_converge_in_precision {|n| cos(BigDecimal(PI(50) * 201 / 2), n) } - assert_operator(cos(PI(30), 30), :>=, -1) - assert_operator(cos(PI(30) * 2, 30), :<=, 1) + [:up, :down].each do |mode| + BigDecimal.save_rounding_mode do + BigDecimal.mode(BigDecimal::ROUND_MODE, mode) + assert_operator(cos(PI(30), 30), :>=, -1) + assert_operator(cos(PI(30) * 2, 30), :<=, 1) + end + end end def test_tan @@ -388,26 +398,20 @@ def test_exp def test_log assert_equal(0, log(BigDecimal("1.0"), 10)) - assert_in_epsilon(Math.log(10)*1000, log(BigDecimal("1e1000"), 10)) + 
assert_in_epsilon(1000 * Math.log(10), log(BigDecimal("1e1000"), 10)) + assert_in_epsilon(19999999999999 * Math.log(10), log(BigDecimal("1E19999999999999"), 10)) + assert_in_epsilon(-19999999999999 * Math.log(10), log(BigDecimal("1E-19999999999999"), 10)) assert_in_exact_precision( BigDecimal("2.3025850929940456840179914546843642076011014886287729760333279009675726096773524802359972050895982983419677840422862"), log(BigDecimal("10"), 100), 100 ) assert_converge_in_precision {|n| log(BigDecimal("2"), n) } - assert_converge_in_precision {|n| log(BigDecimal("1e-30") + 1, n) } - assert_converge_in_precision {|n| log(BigDecimal("1e-30"), n) } + assert_converge_in_precision {|n| log(1 + SQRT2 * BigDecimal("1e-30"), n) } + assert_converge_in_precision {|n| log(SQRT2 * BigDecimal("1e-30"), n) } assert_converge_in_precision {|n| log(BigDecimal("1e30"), n) } assert_converge_in_precision {|n| log(SQRT2, n) } assert_raise(Math::DomainError) {log(BigDecimal("-0.1"), 10)} - begin - x = BigDecimal("1E19999999999999") - rescue FloatDomainError - else - unless x.infinite? 
- assert_in_epsilon(Math.log(10) * 19999999999999, BigMath.log(x, 10)) - end - end end def test_log2 @@ -469,4 +473,65 @@ def test_expm1 assert_in_exact_precision(exp(BigDecimal("1.23e-10"), 120) - 1, expm1(BigDecimal("1.23e-10"), 100), 100) assert_in_exact_precision(exp(123, 120) - 1, expm1(BigDecimal("123"), 100), 100) end + + def test_erf + [-0.5, 0.1, 0.3, 2.1, 3.3].each do |x| + assert_in_epsilon(Math.erf(x), BigMath.erf(BigDecimal(x.to_s), N)) + end + assert_equal(1, BigMath.erf(PINF, N)) + assert_equal(-1, BigMath.erf(MINF, N)) + assert_equal(1, BigMath.erf(BigDecimal(1000), 100)) + assert_equal(-1, BigMath.erf(BigDecimal(-1000), 100)) + assert_not_equal(1, BigMath.erf(BigDecimal(10), 45)) + assert_not_equal(1, BigMath.erf(BigDecimal(15), 100)) + assert_equal(1, BigMath.erf(BigDecimal('1e400'), 10)) + assert_equal(-1, BigMath.erf(BigDecimal('-1e400'), 10)) + assert_equal( + BigDecimal("0.9953222650189527341620692563672529286108917970400600767383523262004372807199951773676290080196806805"), + BigMath.erf(BigDecimal("2"), 100) + ) + assert_converge_in_precision {|n| BigMath.erf(BigDecimal("1e-30"), n) } + assert_converge_in_precision {|n| BigMath.erf(BigDecimal("0.3"), n) } + assert_converge_in_precision {|n| BigMath.erf(SQRT2, n) } + end + + def test_erfc + [-0.5, 0.1, 0.3, 2.1, 3.3].each do |x| + assert_in_epsilon(Math.erfc(x), BigMath.erfc(BigDecimal(x.to_s), N)) + end + assert_equal(0, BigMath.erfc(PINF, N)) + assert_equal(2, BigMath.erfc(MINF, N)) + assert_equal(0, BigMath.erfc(BigDecimal('1e400'), 10)) + assert_equal(2, BigMath.erfc(BigDecimal('-1e400'), 10)) + + # erfc with taylor series + assert_equal( + BigDecimal("2.088487583762544757000786294957788611560818119321163727012213713938174695833440290610766384285723554e-45"), + BigMath.erfc(BigDecimal("10"), 100) + ) + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(0.3), n) } + assert_converge_in_precision {|n| BigMath.erfc(SQRT2, n) } + assert_converge_in_precision {|n| 
BigMath.erfc(BigDecimal(8), n) } + # erfc with asymptotic expansion + assert_equal( + BigDecimal("1.896961059966276509268278259713415434936907563929186183462834752900411805205111886605256690776760041e-697"), + BigMath.erfc(BigDecimal("40"), 100) + ) + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(30), n) } + assert_converge_in_precision {|n| BigMath.erfc(30 * SQRT2, n) } + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(50), n) } + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(60000), n) } + # Near crossover point between taylor series and asymptotic expansion around prec=150 + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(19.5), n) } + assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(20.5), n) } + end + + def test_erf_erfc_consistency_large_prec + [BigDecimal(34.5), 34 + BigDecimal(4).div(7, 1200)].each do |x| + erf = BigMath.erf(x, 1200) # Calculated with taylor series of erf + erfc = BigMath.erfc(x, 400) # Calculated with asymptotic expansion + erfc2 = 1 - erf + assert_equal(erfc, erfc2.mult(1, 400)) + end + end end diff --git a/test/bigdecimal/test_vp_operation.rb b/test/bigdecimal/test_vp_operation.rb index 075df0b6..389aba77 100644 --- a/test/bigdecimal/test_vp_operation.rb +++ b/test/bigdecimal/test_vp_operation.rb @@ -13,6 +13,10 @@ def setup end end + def ntt_mult_available? + BASE_FIG == 9 + end + def test_vpmult assert_equal(BigDecimal('121932631112635269'), BigDecimal('123456789').vpmult(BigDecimal('987654321'))) assert_equal(BigDecimal('12193263.1112635269'), BigDecimal('123.456789').vpmult(BigDecimal('98765.4321'))) @@ -21,6 +25,68 @@ def test_vpmult assert_equal(BigDecimal("#{x * y}e-300"), BigDecimal("#{x}e-100").vpmult(BigDecimal("#{y}e-200"))) end + def test_nttmult + omit 'NTT multiplication is only available for 32-bit DECDIG' unless ntt_mult_available? 
+ [*1..32].repeated_permutation(2) do |a, b| + x = BigDecimal(10 ** (BASE_FIG * a) / 7) + y = BigDecimal(10 ** (BASE_FIG * b) / 13) + assert_equal(x.to_i * y.to_i, x.nttmult(y)) + end + end + + def test_newton_inverse + xs = [BigDecimal(3), BigDecimal('123e50'), BigDecimal('13' * 44), BigDecimal('17' * 45), BigDecimal('19' * 46)] + %i[up half_up down].each do |rounding_mode| + BigDecimal.save_rounding_mode do + BigDecimal.mode(BigDecimal::ROUND_MODE, rounding_mode) + [*1..32, 50, 100, 200, 300].each do |prec| + xs.each do |x| + inv = x.newton_raphson_inverse(prec) + assert_in_delta(1, x * inv, BigDecimal("1e#{1 - prec}")) + + high_precision_inv = inv * (2 - x * inv) + expected_inv = high_precision_inv.mult(1, prec) + last_digit = BigDecimal("1e#{expected_inv.exponent - prec}") + assert_include([expected_inv - last_digit, expected_inv, expected_inv + last_digit], inv) + end + end + end + end + end + + def test_not_affected_by_limit + x_int = 123**135 + y_int = 135**123 + xy_int = x_int * y_int + mod_int = 111**111 + x = BigDecimal(x_int) + y = BigDecimal(y_int) + xy = BigDecimal(xy_int) + mod = BigDecimal(mod_int) + z = BigDecimal(xy_int + mod_int) + BigDecimal.save_limit do + BigDecimal.limit 3 + assert_equal(xy, x.vpmult(y)) + assert_equal(3, BigDecimal.limit) + if ntt_mult_available? 
+ assert_equal(xy, x.nttmult(y)) + assert_equal(3, BigDecimal.limit) + end + + prec = (z.exponent - 1) / BASE_FIG - (y.exponent - 1) / BASE_FIG + 1 + assert_equal([x, mod], z.vpdivd(y, prec)) + assert_equal(3, BigDecimal.limit) + assert_equal([x, mod], z.vpdivd_newton(y, prec)) + assert_equal(3, BigDecimal.limit) + end + end + + def assert_vpdivd_equal(expected_divmod, x_y_n) + x, *args = x_y_n + assert_equal(expected_divmod, x.vpdivd(*args)) + assert_equal(expected_divmod, x.vpdivd_newton(*args)) + end + def test_vpdivd # a[0] > b[0] # XXXX_YYYY_ZZZZ / 1111 #=> 000X_000Y_000Z @@ -31,11 +97,11 @@ def test_vpdivd d3 = BigDecimal("4e#{BASE_FIG * 2}") + d2 d4 = BigDecimal("5e#{BASE_FIG}") + d3 d5 = BigDecimal(6) + d4 - assert_equal([d1, x1 - d1 * y], x1.vpdivd(y, 1)) - assert_equal([d2, x1 - d2 * y], x1.vpdivd(y, 2)) - assert_equal([d3, x1 - d3 * y], x1.vpdivd(y, 3)) - assert_equal([d4, x1 - d4 * y], x1.vpdivd(y, 4)) - assert_equal([d5, x1 - d5 * y], x1.vpdivd(y, 5)) + assert_vpdivd_equal([d1, x1 - d1 * y], [x1, y, 1]) + assert_vpdivd_equal([d2, x1 - d2 * y], [x1, y, 2]) + assert_vpdivd_equal([d3, x1 - d3 * y], [x1, y, 3]) + assert_vpdivd_equal([d4, x1 - d4 * y], [x1, y, 4]) + assert_vpdivd_equal([d5, x1 - d5 * y], [x1, y, 5]) # a[0] < b[0] # 00XX_XXYY_YYZZ_ZZ00 / 1111 #=> 0000_0X00_0Y00_0Z00 @@ -46,28 +112,28 @@ def test_vpdivd d3 = BigDecimal("4e#{2 * BASE_FIG + shift}") + d2 d4 = BigDecimal("5e#{BASE_FIG + shift}") + d3 d5 = BigDecimal("6e#{shift}") + d4 - assert_equal([0, x2], x2.vpdivd(y, 1)) - assert_equal([d1, x2 - d1 * y], x2.vpdivd(y, 2)) - assert_equal([d2, x2 - d2 * y], x2.vpdivd(y, 3)) - assert_equal([d3, x2 - d3 * y], x2.vpdivd(y, 4)) - assert_equal([d4, x2 - d4 * y], x2.vpdivd(y, 5)) - assert_equal([d5, x2 - d5 * y], x2.vpdivd(y, 6)) + assert_vpdivd_equal([0, x2], [x2, y, 1]) + assert_vpdivd_equal([d1, x2 - d1 * y], [x2, y, 2]) + assert_vpdivd_equal([d2, x2 - d2 * y], [x2, y, 3]) + assert_vpdivd_equal([d3, x2 - d3 * y], [x2, y, 4]) + 
assert_vpdivd_equal([d4, x2 - d4 * y], [x2, y, 5]) + assert_vpdivd_equal([d5, x2 - d5 * y], [x2, y, 6]) end def test_vpdivd_large_quotient_prec # 0001 / 0003 = 0000_3333_3333 - assert_equal([BigDecimal('0.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], BigDecimal(1).vpdivd(BigDecimal(3), 10)) + assert_vpdivd_equal([BigDecimal('0.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], [BigDecimal(1), BigDecimal(3), 10]) # 1000 / 0003 = 0333_3333_3333 - assert_equal([BigDecimal('3' * (BASE_FIG - 1) + '.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], BigDecimal(BASE / 10).vpdivd(BigDecimal(3), 10)) + assert_vpdivd_equal([BigDecimal('3' * (BASE_FIG - 1) + '.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], [BigDecimal(BASE / 10), BigDecimal(3), 10]) end def test_vpdivd_with_one x = BigDecimal('1234.2468000001234') - assert_equal([BigDecimal('1234'), BigDecimal('0.2468000001234')], x.vpdivd(BigDecimal(1), 1)) - assert_equal([BigDecimal('+1234.2468'), BigDecimal('+0.1234e-9')], (+x).vpdivd(BigDecimal(+1), 2)) - assert_equal([BigDecimal('-1234.2468'), BigDecimal('+0.1234e-9')], (+x).vpdivd(BigDecimal(-1), 2)) - assert_equal([BigDecimal('-1234.2468'), BigDecimal('-0.1234e-9')], (-x).vpdivd(BigDecimal(+1), 2)) - assert_equal([BigDecimal('+1234.2468'), BigDecimal('-0.1234e-9')], (-x).vpdivd(BigDecimal(-1), 2)) + assert_vpdivd_equal([BigDecimal('1234'), BigDecimal('0.2468000001234')], [x, BigDecimal(1), 1]) + assert_vpdivd_equal([BigDecimal('+1234.2468'), BigDecimal('+0.1234e-9')], [+x, BigDecimal(+1), 2]) + assert_vpdivd_equal([BigDecimal('-1234.2468'), BigDecimal('+0.1234e-9')], [+x, BigDecimal(-1), 2]) + assert_vpdivd_equal([BigDecimal('-1234.2468'), BigDecimal('-0.1234e-9')], [-x, BigDecimal(+1), 2]) + assert_vpdivd_equal([BigDecimal('+1234.2468'), BigDecimal('-0.1234e-9')], [-x, BigDecimal(-1), 2]) end def test_vpdivd_precisions @@ -79,7 +145,7 @@ def test_vpdivd_precisions yn = (y.digits.size + BASE_FIG - 1) / BASE_FIG base = BASE 
** (n - xn + yn - 1) div = BigDecimal((x * base / y).to_i) / base - assert_equal([div, x - y * div], BigDecimal(x).vpdivd(y, n)) + assert_vpdivd_equal([div, x - y * div], [BigDecimal(x), BigDecimal(y), n]) end end end @@ -92,7 +158,7 @@ def test_vpdivd_borrow x = y * (3 * BASE**4 + a * BASE**3 + b * BASE**2 + c * BASE + d) / BASE div = BigDecimal(x * BASE / y) / BASE mod = BigDecimal(x) - div * y - assert_equal([div, mod], BigDecimal(x).vpdivd(BigDecimal(y), 5)) + assert_vpdivd_equal([div, mod], [BigDecimal(x), BigDecimal(y), 5]) end end end @@ -104,22 +170,22 @@ def test_vpdivd_large_prec_divisor divy1_1 = BigDecimal(2) divy2_1 = BigDecimal(1) divy2_2 = BigDecimal('1.' + '9' * BASE_FIG) - assert_equal([divy1_1, x - y1 * divy1_1], x.vpdivd(y1, 1)) - assert_equal([divy2_1, x - y2 * divy2_1], x.vpdivd(y2, 1)) - assert_equal([divy2_2, x - y2 * divy2_2], x.vpdivd(y2, 2)) + assert_vpdivd_equal([divy1_1, x - y1 * divy1_1], [x, y1, 1]) + assert_vpdivd_equal([divy2_1, x - y2 * divy2_1], [x, y2, 1]) + assert_vpdivd_equal([divy2_2, x - y2 * divy2_2], [x, y2, 2]) end def test_vpdivd_intermediate_zero if BASE_FIG == 9 x = BigDecimal('123456789.246913578000000000123456789') y = BigDecimal('123456789') - assert_equal([BigDecimal('1.000000002000000000000000001'), BigDecimal(0)], x.vpdivd(y, 4)) - assert_equal([BigDecimal('1.000000000049999999'), BigDecimal('1e-18')], BigDecimal("2.000000000099999999").vpdivd(2, 3)) + assert_vpdivd_equal([BigDecimal('1.000000002000000000000000001'), BigDecimal(0)], [x, y, 4]) + assert_vpdivd_equal([BigDecimal('1.000000000049999999'), BigDecimal('1e-18')], [BigDecimal("2.000000000099999999"), 2, 3]) else x = BigDecimal('1234.246800001234') y = BigDecimal('1234') - assert_equal([BigDecimal('1.000200000001'), BigDecimal(0)], x.vpdivd(y, 4)) - assert_equal([BigDecimal('1.00000499'), BigDecimal('1e-8')], BigDecimal("2.00000999").vpdivd(2, 3)) + assert_vpdivd_equal([BigDecimal('1.000200000001'), BigDecimal(0)], [x, y, 4]) + 
assert_vpdivd_equal([BigDecimal('1.00000499'), BigDecimal('1e-8')], [BigDecimal("2.00000999"), 2, 3]) end end end