|
1 | 1 | using SpecialFunctions |
2 | 2 | import Base.Broadcast |
3 | 3 |
|
4 | | -const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj] |
5 | 4 |
|
6 | | -const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh] |
| 5 | +const linearity_known_1 = IdDict{Function,Bool}() |
| 6 | +const linearity_known_2 = IdDict{Function,Bool}() |
| 7 | + |
| 8 | +const linearity_map_1 = IdDict{Function, Bool}() |
| 9 | +const linearity_map_2 = IdDict{Function, Tuple{Bool, Bool, Bool}}() |
7 | 10 |
|
8 | | -diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal, max, min, convert] |
9 | | -diadic_of_linearity(::Val{(true, true, false)}) = [*] |
10 | | -#diadic_of_linearit(::(Val{(true, false, true)}) = [besselk, hankelh2, bessely, besselj, besseli, polygamma, hankelh1] |
11 | | -diadic_of_linearity(::Val{(true, false, false)}) = [/] |
12 | | -diadic_of_linearity(::Val{(false, true, false)}) = [\] |
13 | | -diadic_of_linearity(::Val{(false, false, false)}) = [hypot, atan, mod, rem, lbeta, ^, beta] |
14 | | -diadic_of_linearity(::Val) = [] |
| 11 | +# 1-arg |
| 12 | + |
| 13 | +const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj] |
15 | 14 |
|
16 | | -haslinearity(f, nargs) = false |
| 15 | +const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh] |
17 | 16 |
|
18 | | -# linearity of a single input function is either |
19 | | -# Val{true}() or Val{false}() |
20 | | -# |
| 17 | +# We store 3 bools even for 1-arg functions for type stability |
| 18 | +const three_trues = (true, true, true) |
21 | 19 | for f in monadic_linear |
22 | | - @eval begin |
23 | | - haslinearity(::typeof($f), ::Val{1}) = true |
24 | | - linearity(::typeof($f), ::Val{1}) = Val{true}() |
25 | | - end |
| 20 | + linearity_known_1[f] = true |
| 21 | + linearity_map_1[f] = true |
26 | 22 | end |
27 | | -# linearity of a 2-arg function is: |
28 | | -# Val{(linear11, linear22, linear12)}() |
29 | | -# |
30 | | -# linearIJ refers to the zeroness of d^2/dxIxJ |
| 23 | + |
31 | 24 | for f in monadic_nonlinear |
32 | | - @eval begin |
33 | | - haslinearity(::typeof($f), ::Val{1}) = true |
34 | | - linearity(::typeof($f), ::Val{1}) = Val{false}() |
35 | | - end |
| 25 | + linearity_known_1[f] = true |
| 26 | + linearity_map_1[f] = false |
36 | 27 | end |
37 | 28 |
|
38 | | -for linearity_mask = 0:2^3-1 |
39 | | - lin = Val{map(x->x!=0, (linearity_mask & 4, |
40 | | - linearity_mask & 2, |
41 | | - linearity_mask & 1))}() |
| 29 | +# 2-arg |
| 30 | +for f in [+, rem2pi, -, >, isless, <, isequal, max, min, convert] |
| 31 | + linearity_known_2[f] = true |
| 32 | + linearity_map_2[f] = (true, true, true) |
| 33 | +end |
42 | 34 |
|
43 | | - for f in diadic_of_linearity(lin) |
44 | | - @eval begin |
45 | | - haslinearity(::typeof($f), ::Val{2}) = true |
46 | | - linearity(::typeof($f), ::Val{2}) = $lin |
47 | | - end |
48 | | - end |
| 35 | +for f in [*] |
| 36 | + linearity_known_2[f] = true |
| 37 | + linearity_map_2[f] = (true, true, false) |
| 38 | +end |
| 39 | + |
| 40 | +for f in [/] |
| 41 | + linearity_known_2[f] = true |
| 42 | + linearity_map_2[f] = (true, false, false) |
| 43 | +end |
| 44 | +for f in [\] |
| 45 | + linearity_known_2[f] = true |
| 46 | + linearity_map_2[f] = (false, true, false) |
49 | 47 | end |
50 | 48 |
|
| 49 | +for f in [hypot, atan, mod, rem, lbeta, ^, beta] |
| 50 | + linearity_known_2[f] = true |
| 51 | + linearity_map_2[f] = (false, false, false) |
| 52 | +end |
| 53 | + |
| 54 | +haslinearity_1(@nospecialize(f)) = get(linearity_known_1, f, false) |
| 55 | +haslinearity_2(@nospecialize(f)) = get(linearity_known_2, f, false) |
| 56 | + |
| 57 | +linearity_1(@nospecialize(f)) = linearity_map_1[f] |
| 58 | +linearity_2(@nospecialize(f)) = linearity_map_2[f] |
| 59 | + |
51 | 60 | # TermCombination datastructure |
52 | 61 |
|
53 | 62 | struct TermCombination |
@@ -151,11 +160,10 @@ function _sparse(t::TermCombination, n) |
151 | 160 | end |
152 | 161 |
|
153 | 162 | # 1-arg functions |
154 | | -combine_terms(::Val{true}, term) = term |
155 | | -combine_terms(::Val{false}, term) = term * term |
| 163 | +combine_terms_1(lin, term) = lin ? term : term * term |
156 | 164 |
|
157 | 165 | # 2-arg functions |
158 | | -function combine_terms(::Val{linearity}, term1, term2) where linearity |
| 166 | +function combine_terms_2(linearity, term1, term2) |
159 | 167 |
|
160 | 168 | linear11, linear22, linear12 = linearity |
161 | 169 | term = zero(TermCombination) |
|
0 commit comments