Skip to content

Commit 197b023

Browse files
Merge pull request #550 from SciML/s/dict-linearity
Dict-based linearity
2 parents 3e621ee + ad11b1f commit 197b023

File tree

3 files changed

+51
-43
lines changed

3 files changed

+51
-43
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Requires = "1.0"
4747
SafeTestsets = "0.0.1"
4848
SpecialFunctions = "0.7, 0.8, 0.9, 0.10"
4949
StaticArrays = "0.10, 0.11, 0.12"
50-
SymbolicUtils = "0.4.3"
50+
SymbolicUtils = "0.5"
5151
TreeViews = "0.3"
5252
UnPack = "0.1, 1.0"
5353
Unitful = "1.1"

src/direct.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,17 @@ let
117117
@rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar)
118118
@rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar)
119119
@rule (~f)(~x::(!isidx)) => _scalar
120-
@rule (~f)(~x::isidx) => if haslinearity(~f, Val{1}())
121-
combine_terms(linearity(~f, Val{1}()), ~x)
120+
@rule (~f)(~x::isidx) => if haslinearity_1(~f)
121+
combine_terms_1(linearity_1(~f), ~x)
122122
else
123123
error("Function of unknown linearity used: ", ~f)
124124
end
125125
@rule (^)(~x::isidx, ~y) => ~y isa Number && isone(~y) ? ~x : (~x) * (~x)
126126
@rule (~f)(~x, ~y) => begin
127-
if haslinearity(~f, Val{2}())
127+
if haslinearity_2(~f)
128128
a = isidx(~x) ? ~x : _scalar
129129
b = isidx(~y) ? ~y : _scalar
130-
combine_terms(linearity(~f, Val{2}()), a, b)
130+
combine_terms_2(linearity_2(~f), a, b)
131131
else
132132
error("Function of unknown linearity used: ", ~f)
133133
end

src/linearity.jl

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,62 @@
11
using SpecialFunctions
22
import Base.Broadcast
33

4-
const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj]
54

6-
const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh]
5+
const linearity_known_1 = IdDict{Function,Bool}()
6+
const linearity_known_2 = IdDict{Function,Bool}()
7+
8+
const linearity_map_1 = IdDict{Function, Bool}()
9+
const linearity_map_2 = IdDict{Function, Tuple{Bool, Bool, Bool}}()
710

8-
diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal, max, min, convert]
9-
diadic_of_linearity(::Val{(true, true, false)}) = [*]
10-
#diadic_of_linearit(::(Val{(true, false, true)}) = [besselk, hankelh2, bessely, besselj, besseli, polygamma, hankelh1]
11-
diadic_of_linearity(::Val{(true, false, false)}) = [/]
12-
diadic_of_linearity(::Val{(false, true, false)}) = [\]
13-
diadic_of_linearity(::Val{(false, false, false)}) = [hypot, atan, mod, rem, lbeta, ^, beta]
14-
diadic_of_linearity(::Val) = []
11+
# 1-arg
12+
13+
const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj]
1514

16-
haslinearity(f, nargs) = false
15+
const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh]
1716

18-
# linearity of a single input function is either
19-
# Val{true}() or Val{false}()
20-
#
17+
# We store 3 bools even for 1-arg functions for type stability
18+
const three_trues = (true, true, true)
2119
for f in monadic_linear
22-
@eval begin
23-
haslinearity(::typeof($f), ::Val{1}) = true
24-
linearity(::typeof($f), ::Val{1}) = Val{true}()
25-
end
20+
linearity_known_1[f] = true
21+
linearity_map_1[f] = true
2622
end
27-
# linearity of a 2-arg function is:
28-
# Val{(linear11, linear22, linear12)}()
29-
#
30-
# linearIJ refers to the zeroness of d^2/dxIxJ
23+
3124
for f in monadic_nonlinear
32-
@eval begin
33-
haslinearity(::typeof($f), ::Val{1}) = true
34-
linearity(::typeof($f), ::Val{1}) = Val{false}()
35-
end
25+
linearity_known_1[f] = true
26+
linearity_map_1[f] = false
3627
end
3728

38-
for linearity_mask = 0:2^3-1
39-
lin = Val{map(x->x!=0, (linearity_mask & 4,
40-
linearity_mask & 2,
41-
linearity_mask & 1))}()
29+
# 2-arg
30+
for f in [+, rem2pi, -, >, isless, <, isequal, max, min, convert]
31+
linearity_known_2[f] = true
32+
linearity_map_2[f] = (true, true, true)
33+
end
4234

43-
for f in diadic_of_linearity(lin)
44-
@eval begin
45-
haslinearity(::typeof($f), ::Val{2}) = true
46-
linearity(::typeof($f), ::Val{2}) = $lin
47-
end
48-
end
35+
for f in [*]
36+
linearity_known_2[f] = true
37+
linearity_map_2[f] = (true, true, false)
38+
end
39+
40+
for f in [/]
41+
linearity_known_2[f] = true
42+
linearity_map_2[f] = (true, false, false)
43+
end
44+
for f in [\]
45+
linearity_known_2[f] = true
46+
linearity_map_2[f] = (false, true, false)
4947
end
5048

49+
for f in [hypot, atan, mod, rem, lbeta, ^, beta]
50+
linearity_known_2[f] = true
51+
linearity_map_2[f] = (false, false, false)
52+
end
53+
54+
haslinearity_1(@nospecialize(f)) = get(linearity_known_1, f, false)
55+
haslinearity_2(@nospecialize(f)) = get(linearity_known_2, f, false)
56+
57+
linearity_1(@nospecialize(f)) = linearity_map_1[f]
58+
linearity_2(@nospecialize(f)) = linearity_map_2[f]
59+
5160
# TermCombination datastructure
5261

5362
struct TermCombination
@@ -151,11 +160,10 @@ function _sparse(t::TermCombination, n)
151160
end
152161

153162
# 1-arg functions
154-
combine_terms(::Val{true}, term) = term
155-
combine_terms(::Val{false}, term) = term * term
163+
combine_terms_1(lin, term) = lin ? term : term * term
156164

157165
# 2-arg functions
158-
function combine_terms(::Val{linearity}, term1, term2) where linearity
166+
function combine_terms_2(linearity, term1, term2)
159167

160168
linear11, linear22, linear12 = linearity
161169
term = zero(TermCombination)

0 commit comments

Comments
 (0)