Skip to content

Commit ace2d32

Browse files
committed
BigMath.erf and BigMath.erfc with bit burst algorithm
1 parent a3ea7d6 commit ace2d32

File tree

2 files changed

+366
-31
lines changed

2 files changed

+366
-31
lines changed

lib/bigdecimal/math.rb

Lines changed: 305 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,300 @@ def expm1(x, prec)
576576
exp_prec > 0 ? exp(x, exp_prec).sub(1, prec) : BigDecimal(-1)
577577
end
578578

579+
# call-seq:
580+
# erf(decimal, numeric) -> BigDecimal
581+
#
582+
# Computes the error function of +decimal+ to the specified number of digits of
583+
# precision, +numeric+.
584+
#
585+
# If +decimal+ is NaN, returns NaN.
586+
#
587+
# BigMath.erf(BigDecimal('1'), 32).to_s
588+
# #=> "0.84270079294971486934122063508261e0"
589+
#
590+
def erf(x, prec)
591+
prec = BigDecimal::Internal.coerce_validate_prec(prec, :erf)
592+
x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :erf)
593+
return BigDecimal::Internal.nan_computation_result if x.nan?
594+
return BigDecimal(x.infinite?) if x.infinite?
595+
return BigDecimal(0) if x == 0
596+
return -erf(-x, prec) if x < 0
597+
return BigDecimal(1) if x > 5000000000 # erf(5000000000) > 1 - 1e-10000000000000000000
598+
599+
if x > 8
600+
xf = x.to_f
601+
log10_erfc = -xf ** 2 / Math.log(10) - Math.log10(xf * Math::PI ** 0.5)
602+
erfc_prec = [prec + log10_erfc.ceil, 1].max
603+
erfc = _erfc_bit_burst(x, erfc_prec + BigDecimal.double_fig)
604+
return BigDecimal(1).sub(erfc, prec) if erfc
605+
end
606+
607+
_erf_bit_burst(x, prec + BigDecimal.double_fig).mult(1, prec)
608+
end
609+
610+
def erfc(x, prec)
611+
prec = BigDecimal::Internal.coerce_validate_prec(prec, :erfc)
612+
x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :erfc)
613+
return BigDecimal::Internal.nan_computation_result if x.nan?
614+
return BigDecimal(1 - x.infinite?) if x.infinite?
615+
return BigDecimal(1).sub(erf(x, prec), prec) if x < 0
616+
return BigDecimal(0) if x > 5000000000 # erfc(5000000000) < 1e-10000000000000000000 (underflow)
617+
618+
if x >= 8
619+
y = _erfc_bit_burst(x, prec + BigDecimal.double_fig)
620+
return y.mult(1, prec) if y
621+
end
622+
623+
# erfc(x) = 1 - erf(x) < exp(-x**2)/x/sqrt(pi)
624+
# Precision of erf(x) needs about log10(exp(-x**2)) extra digits
625+
log10 = 2.302585092994046
626+
high_prec = prec + BigDecimal.double_fig + (x.to_f**2 / log10).ceil
627+
BigDecimal(1).sub(_erf_bit_burst(x, high_prec), prec)
628+
end
629+
630+
# Calculates erf(x) using bit-burst algorithm.
631+
private_class_method def _erf_bit_burst(x, prec)
632+
x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :erf)
633+
prec = BigDecimal::Internal.coerce_validate_prec(prec, :erf)
634+
635+
return BigDecimal(0) if x > 5000000000 # erfc underflows
636+
x = x.mult(1, [prec - (x.to_f**2/Math.log(10)).floor, 1].max)
637+
638+
calculated_x = BigDecimal(0)
639+
erf_exp2 = BigDecimal(0)
640+
digits = 8
641+
xf = x.to_f
642+
scale = 2 * exp(-x.mult(x, prec), prec).div(PI(prec).sqrt(prec), prec)
643+
644+
until x.zero?
645+
partial = x.truncate(digits)
646+
digits *= 2
647+
next if partial.zero?
648+
649+
erf_exp2 = _erf_exp2_binary_splitting(partial, calculated_x, erf_exp2, prec)
650+
calculated_x += partial
651+
x -= partial
652+
end
653+
erf_exp2.mult(scale, prec)
654+
end
655+
656+
# Calculates erfc(x) using bit-burst algorithm.
657+
private_class_method def _erfc_bit_burst(x, prec) # :nodoc:
658+
digits = (x.exponent + 1) * 40
659+
660+
calculated_x = x.truncate(digits)
661+
f = _erfc_exp2_asymptotic_binary_splitting(calculated_x, prec)
662+
return unless f
663+
664+
scale = 2 * exp(-x.mult(x, prec), prec).div(PI(prec).sqrt(prec), prec)
665+
x -= calculated_x
666+
667+
until x.zero?
668+
digits *= 2
669+
partial = x.truncate(digits)
670+
next if partial.zero?
671+
672+
f = _erfc_exp2_inv_inv_binary_splitting(partial, calculated_x, f, prec)
673+
calculated_x += partial
674+
x -= partial
675+
end
676+
f.mult(scale, prec)
677+
end
678+
679+
# Matrix multiplication for binary splitting method in erf/erfc calculation
680+
private_class_method def _bs_matrix_mult(m1, m2, size, prec) # :nodoc:
681+
(size * size).times.map do |i|
682+
size.times.map do |k|
683+
m1[i / size * size + k].mult(m2[size * k + i % size], prec)
684+
end.reduce {|a, b| a.add(b, prec) }
685+
end
686+
end
687+
688+
# Matrix/Vector weighted sum for binary splitting method in erf/erfc calculation
689+
private_class_method def _bs_weighted_sum(m1, w1, m2, w2, prec) # :nodoc:
690+
m1.zip(m2).map {|v1, v2| (v1 * w1).add(v2 * w2, prec) }
691+
end
692+
693+
# Calculates Taylor expansion of erf(x+a)*exp((x+a)**2)*sqrt(pi)/2 with binary splitting method.
694+
private_class_method def _erf_exp2_binary_splitting(x, a, f_a, prec) # :nodoc:
695+
cexponent = Math.log10([2 * a, Math.sqrt(2)].max.to_f) + log10(x.abs, 10)
696+
log10f = Math.log(10)
697+
698+
steps = BigDecimal.save_exception_mode do
699+
BigDecimal.mode(BigDecimal::EXCEPTION_UNDERFLOW, false)
700+
(2..).bsearch do |n|
701+
x.to_f ** 2 < n && n * cexponent + Math.lgamma(n / 2)[0] / log10f + n * Math.log10(2) - Math.lgamma(n - 1)[0] / log10f < -prec + x.to_f**2 / log10f
702+
end
703+
end
704+
705+
if a == 0
706+
denominators = (steps / 2).times.map {|i| 2 * i + 3 }
707+
return x.mult(1 + BigDecimal::Internal.taylor_sum_binary_splitting(2 * x * x, denominators, prec), prec)
708+
end
709+
710+
# First, calculate a matrix that represents the sum of the Taylor series:
711+
# SumMatrix = (((((...+I)x*M4+I)*x*M3+I)*M2*x+I)*M1*x+I)
712+
# Where Mi is a 2x2 matrix that generates the next coefficients of Taylor series:
713+
# Vector(c4, c5) = M4*M3*M2*M1*Vector(c0, c1)
714+
# And then calculates:
715+
# SumMatrix * Vector(c0, c1) = Vector(c0+c1*x+c2*x**2+..., _)
716+
# In this binary splitting method, adjacent two operations are combined into one repeatedly.
717+
# ((...) * x * A + B) / C is the form of each operation. A and B are 2x2 matrices, C is a scalar.
718+
zero = BigDecimal(0)
719+
two = BigDecimal(2)
720+
two_a = two * a
721+
operations = steps.times.map do |i|
722+
n = BigDecimal(2 + i)
723+
[[zero, n, two, two_a], [n, zero, zero, n], n]
724+
end
725+
726+
while operations.size > 1
727+
xpow = xpow ? xpow.mult(xpow, prec) : x.mult(1, prec)
728+
operations = operations.each_slice(2).map do |op1, op2|
729+
# Combine two operations into one:
730+
# (((Remaining * x * A2 + B2) / C2) * x * A1 + B1) / C1
731+
# ((Remaining * (x*x) * (A2*A1) + (x*B2*A1+B1*C2)) / (C1*C2)
732+
# Therefore, combined operation can be represented as:
733+
# Anext = A2 * A1
734+
# Bnext = x * B2 * A1 + B1 * C2
735+
# Cnext = C1 * C2
736+
# xnext = x * x
737+
a1, b1, c1 = op1
738+
a2, b2, c2 = op2 || [[zero] * 4, [zero] * 4, BigDecimal(1)]
739+
[
740+
_bs_matrix_mult(a2, a1, 2, prec),
741+
_bs_weighted_sum(_bs_matrix_mult(b2, a1, 2, prec), xpow, b1, c2, prec),
742+
c1.mult(c2, prec),
743+
]
744+
end
745+
end
746+
_, sum_matrix, denominator = operations.first
747+
(sum_matrix[1] + f_a * (2 * a * sum_matrix[1] + sum_matrix[0])).div(denominator, prec)
748+
end
749+
750+
# Calculates asymptotic expansion of erfc(x)*exp(x**2)*sqrt(pi)/2 with binary splitting method
751+
private_class_method def _erfc_exp2_asymptotic_binary_splitting(x, prec) # :nodoc:
752+
# Let f(x) = erfc(x)*sqrt(pi)*exp(x**2)/2
753+
# f(x) satisfies the following differential equation:
754+
# 2*x*f(x) = f'(x) + 1
755+
# From the above equation, we can derive the following asymptotic expansion:
756+
# f(x) = (0..kmax).sum { (-1)**k * (2*k)! / 4**k / k! / x**(2*k)) } / x
757+
758+
# This asymptotic expansion does not converge.
759+
# But if there is a k that satisfies (2*k)! / 4**k / k! / x**(2*k) < 10**(-prec),
760+
# It is enough to calculate erfc within the given precision.
761+
# Using Stirling's approximation, we can simplify this condition to:
762+
# sqrt(2)/2 + k*log(k) - k - 2*k*log(x) < -prec*log(10)
763+
# and the left side is minimized when k = x**2.
764+
xf = x.to_f
765+
kmax = (1..(xf ** 2).floor).bsearch do |k|
766+
Math.log(2) / 2 + k * Math.log(k) - k - 2 * k * Math.log(xf) < -prec * Math.log(10)
767+
end
768+
return unless kmax
769+
770+
# Convert asymptotic expansion to nested form:
771+
# 1 + a/x + a*b/x/x + a*b*c/x/x/x + a*b*c/x/x/x*rest
772+
# = 1 + (a/x) * (1 + (b/x) * (1 + (c/x) * (1 + rest)))
773+
#
774+
# And calculate it with binary splitting:
775+
# (a1/d + b1/d * (a2/d + b2/d * (rest)))
776+
# = ((a1*d+b1*a2)/(d*d) + b1*b2/(d*denominator) * (rest)))
777+
denominator = x.mult(x, prec).mult(2, prec)
778+
fractions = (1..kmax).map do |k|
779+
[denominator, BigDecimal(1 - 2 * k)]
780+
end
781+
while fractions.size > 1
782+
fractions = fractions.each_slice(2).map do |fraction1, fraction2|
783+
a1, b1 = fraction1
784+
a2, b2 = fraction2 || [BigDecimal(0), denominator]
785+
[
786+
a1.mult(denominator, prec).add(b1.mult(a2, prec), prec),
787+
b1.mult(b2, prec),
788+
]
789+
end
790+
denominator = denominator.mult(denominator, prec)
791+
end
792+
sum = fractions[0][0].add(fractions[0][1], prec).div(denominator, prec)
793+
sum.div(x, prec) / 2
794+
end
795+
796+
# Calculates f(1/(a + x)) where f(x) = (sqrt(pi)/2) * exp(1/x**2) * erfc(1/x)
797+
# f(1/(a+x)) = f(1/a - x/(a*(a+x)))
798+
private_class_method def _erfc_exp2_inv_inv_binary_splitting(x, a, f_inva, prec)
799+
return f_inva if x.zero?
800+
# f(x) satisfies the following differential equation:
801+
# (1/a+w)**3*f'(1/a+w) + 2*f(1/a+w) = 1/a + w
802+
# From the above equation, we can derive the following Taylor expansion around x=a:
803+
# Coefficients: f(1/a + w) = c0 + c1*w + c2*w**2 + c3*w**3 + ...
804+
# Constraints:
805+
# (w**3 + 3*w**2/a + 3*w/a**2 + 1/a**3) * (c1 + 2*c2*w + 3*c3*w**2 + 4*c4*w**3 + ...)
806+
# + 2 * (c0 + c1*w + c2*w**2 + c3*w**3 + ...) = 1/a + w
807+
# Recurrence relations:
808+
# c0 = f(1/a)
809+
# c1 = a**2 - 2*c0*a**3
810+
# c2 = (a**3 - 3*c1*a - 2*c1*a**3) / 2
811+
# c3 = -(3*c1*a**2 + 6*c2*a + 2*c2*a**3) / 3
812+
# c(n) = -((n-3)*c(n-3)*a**3 + 3*(n-2)*c(n-2)*a**2 + 3*(n-1)*c(n-1)*a + 2*c(n-1)*a**3) / n
813+
814+
aa = a.mult(a, prec)
815+
aaa = aa.mult(a, prec)
816+
c0 = f_inva
817+
c1 = (aa - 2 * c0 * aaa).mult(1, prec)
818+
c2 = (aaa - 3 * c1 * a - 2 * c1 * aaa).div(2, prec)
819+
820+
# Estimate the number of steps needed to achieve the required precision
821+
low_prec = 16
822+
w = x.div(a.mult(a + x, low_prec), low_prec)
823+
wpow = w.mult(w, low_prec)
824+
cm3, cm2, cm1 = [c0, c1, c2].map {|v| v.mult(1, low_prec) }
825+
a_low, aa_low, aaa_low = [a, aa, aaa].map {|v| v.mult(1, low_prec) }
826+
step = (3..).find do |n|
827+
wpow = wpow.mult(w, low_prec)
828+
cn = -((n - 3) * cm3 * aaa_low + 3 * aa_low * (n - 2) * cm2 + 3 * a_low * (n - 1) * cm1 + 2 * cm1 * aaa_low).div(n, low_prec)
829+
cm3, cm2, cm1 = cm2, cm1, cn
830+
cn.mult(wpow, low_prec).exponent < -prec
831+
end
832+
833+
# Let M(n) be a 3x3 matrix that transforms (c(n),c(n+1),c(n+2)) to (c(n-1),c(n),c(n+1))
834+
# Mn = | 0 1 0 |
835+
# | 0 0 1 |
836+
# | -(n-3)*aaa/n -3*(n-2)*aa/n -2*aaa-3*(n-1)*a/n |
837+
# Vector(c6,c7,c8) = M6*M5*M4*M3*M2*M1 * Vector(c0,c1,c2)
838+
# Vector(c0+c1*y/z+c2*(y/z)**2+..., _, _) = (((... + I)*M3*y/z + I)*M2*y/z + I)*M1*y/z + I) * Vector(c2, c1, c0)
839+
# Perform binary splitting on this nested parenthesized calculation by using the following formula:
840+
# (((...)*A2*y/z + B2)/D2 * A1*y/z + B1)/D1 = (((...)*(A2*A1)*(y*y)/z + (B2*A1*y+z*D2*B1)) / (D1*D2*z)
841+
# where A_n, Bn are matrices and Dn are scalars
842+
843+
zero = BigDecimal(0)
844+
operations = (3..step + 2).map do |n|
845+
bign = BigDecimal(n)
846+
[
847+
[
848+
zero, bign, zero,
849+
zero, zero, bign,
850+
BigDecimal(-(n - 3) * aaa), -3 * (n - 2) * aa, -2 * aaa - 3 * (n - 1) * a
851+
],
852+
[bign, zero, zero, zero, bign, zero, zero, zero, bign],
853+
bign
854+
]
855+
end
856+
857+
z = a.mult(a + x, prec)
858+
while operations.size > 1
859+
y = y ? y.mult(y, prec) : -x.mult(1, prec)
860+
operations = operations.each_slice(2).map do |op1, op2|
861+
a1, b1, d1 = op1
862+
a2, b2, d2 = op2 || [[zero] * 9, [zero] * 9, BigDecimal(1)]
863+
[
864+
_bs_matrix_mult(a2, a1, 3, prec),
865+
_bs_weighted_sum(_bs_matrix_mult(b2, a1, 3, prec), y, b1, d2.mult(z, prec), prec),
866+
d1.mult(d2, prec).mult(z, prec),
867+
]
868+
end
869+
end
870+
_, sum_matrix, denominator = operations[0]
871+
(sum_matrix[0] * c0 + sum_matrix[1] * c1 + sum_matrix[2] * c2).div(denominator, prec)
872+
end
579873

580874
# call-seq:
581875
# PI(numeric) -> BigDecimal
@@ -588,38 +882,18 @@ def expm1(x, prec)
588882
#
589883
def PI(prec)
590884
prec = BigDecimal::Internal.coerce_validate_prec(prec, :PI)
591-
n = prec + BigDecimal.double_fig
592-
zero = BigDecimal("0")
593-
one = BigDecimal("1")
594-
two = BigDecimal("2")
595-
596-
m25 = BigDecimal("-0.04")
597-
m57121 = BigDecimal("-57121")
598-
599-
pi = zero
600-
601-
d = one
602-
k = one
603-
t = BigDecimal("-80")
604-
while d.nonzero? && ((m = n - (pi.exponent - d.exponent).abs) > 0)
605-
m = BigDecimal.double_fig if m < BigDecimal.double_fig
606-
t = t*m25
607-
d = t.div(k,m)
608-
k = k+two
609-
pi = pi + d
610-
end
611-
612-
d = one
613-
k = one
614-
t = BigDecimal("956")
615-
while d.nonzero? && ((m = n - (pi.exponent - d.exponent).abs) > 0)
616-
m = BigDecimal.double_fig if m < BigDecimal.double_fig
617-
t = t.div(m57121,n)
618-
d = t.div(k,m)
619-
pi = pi + d
620-
k = k+two
885+
n = prec + BigDecimal.double_fig
886+
a = BigDecimal(1)
887+
b = BigDecimal(0.5, 0).sqrt(n)
888+
s = BigDecimal(0.25, 0)
889+
t = 1
890+
while a != b && (a - b).exponent > 1 - n
891+
c = (a - b).div(2, n)
892+
a, b = (a + b).div(2, n), (a * b).sqrt(n)
893+
s = s.sub(c * c * t, n)
894+
t *= 2
621895
end
622-
pi.mult(1, prec)
896+
(a * b).div(s, prec)
623897
end
624898

625899
# call-seq:

0 commit comments

Comments
 (0)