Skip to content

Commit 9da744d

Browse files
committed
add type param to Weeks{T}
T may be a complex or real type, depending on the return type desired. This implements an idea from github user elisno
1 parent 9b2c5d8 commit 9da744d

File tree

2 files changed

+56
-29
lines changed

2 files changed

+56
-29
lines changed

src/weeks.jl

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,32 @@ end
1010

1111
### Weeks
1212

13-
mutable struct Weeks <: AbstractWeeks
13+
mutable struct Weeks{T} <: AbstractWeeks
1414
func::Function
1515
Nterms::Int
1616
sigma::Float64
1717
b::Float64
18-
coefficients::Array{Float64,1}
18+
coefficients::Array{T,1}
1919
end
2020

21-
function _get_coefficients(func, Nterms, sigma, b)
22-
a0 = real(_wcoeff(func, Nterms, sigma, b))
21+
function _get_coefficients(func, Nterms, sigma, b, ::Type{T}) where T <:Number
22+
a0 = _wcoeff(func, Nterms, sigma, b, T)
2323
return a0[Nterms+1:2*Nterms]
2424
end
2525

26-
_get_coefficients(w::Weeks) = _get_coefficients(w.func, w.Nterms, w.sigma, w.b)
26+
_get_coefficients(w::Weeks{T}) where T <: Number = _get_coefficients(w.func, w.Nterms, w.sigma, w.b, T)
2727
_set_coefficients(w::Weeks) = (w.coefficients = _get_coefficients(w))
2828

2929
const weeks_default_num_terms = 64
3030

3131
"""
32-
w::Weeks = Weeks(func::Function, Nterms::Integer=64, sigma=1.0, b=1.0)
32+
w::Weeks = Weeks(func::Function, Nterms::Integer=64, sigma=1.0, b=1.0; datatype=Float64)
3333
34-
return `w`, which estimates the inverse Laplace transform of `func` with
34+
Return `w`, which estimates the inverse Laplace transform of `func` with
3535
the Weeks algorithm. `w(t)` evaluates the transform at `t`. The accuracy depends on the choice
36-
of `sigma` and `b`, with the optimal choices depending on `t`.
36+
of `sigma` and `b`, with the optimal choices depending on `t`. `datatype` should agree with
37+
the `DataType` returned by `func`. For convenience, `datatype=Complex` is equivalent to
38+
`datatype=Complex{Float64}`
3739
3840
The call to `Weeks` that creates `w` is expensive relative to evaluation via `w(t)`.
3941
@@ -47,8 +49,11 @@ julia> ft(pi/2)
4749
0.0
4850
```
4951
"""
50-
Weeks(func::Function, Nterms::Integer=weeks_default_num_terms,
51-
sigma=1.0, b=1.0) = Weeks(func, Nterms, sigma, b, _get_coefficients(func, Nterms, sigma, b))
52+
function Weeks(func::Function, Nterms::Integer=weeks_default_num_terms,
53+
sigma=1.0, b=1.0; datatype=Float64)
54+
outdatatype = datatype == Complex ? Complex{Float64} : datatype # allow `Complex` as abbrev for Complex{Float64}
55+
return Weeks{outdatatype}(func, Nterms, sigma, b, _get_coefficients(func, Nterms, sigma, b, outdatatype))
56+
end
5257

5358
function eval_ilt(w::Weeks, t)
5459
L = _laguerre(w.coefficients, 2 * w.b * t)
@@ -127,43 +132,43 @@ end
127132

128133
#### WeeksErr
129134

130-
mutable struct WeeksErr <: AbstractWeeks
135+
mutable struct WeeksErr{T} <: AbstractWeeks
131136
func::Function
132137
Nterms::Int
133138
sigma::Float64
134139
b::Float64
135-
coefficients::Array{Float64,1}
140+
coefficients::Array{T,1}
136141
sa1::Float64
137142
sa2::Float64
138143
end
139144

140-
function _get_coefficients_and_params(func, Nterms, sigma, b)
141-
M = 2 * Nterms
142-
a0 = real(_wcoeff(func,M,sigma,b))
145+
function _get_coefficients_and_params(func, Nterms, sigma, b, ::Type{T}) where T
146+
M = 2 * Nterms # why 2 * Nterms ?
147+
a0 = _wcoeff(func, M, sigma, b, T)
143148
a1 = a0[2*Nterms+1:3*Nterms]
144149
sa1 = sum(abs.(a1))
145150
sa2 = sum(abs.(@view a0[3*Nterms+1:4*Nterms]))
146151
return (a1,sa1,sa2)
147152
end
148153

149-
_get_coefficients_and_params(w::WeeksErr) = _get_coefficients_and_params(w.func, w.Nterms, w.sigma, w.b)
150-
151-
_get_coefficients(w::WeeksErr) = _get_coefficients_and_params(w.func, w.Nterms, w.sigma, w.b)
152-
154+
_get_coefficients_and_params(w::WeeksErr{T}) where T = _get_coefficients_and_params(w.func, w.Nterms, w.sigma, w.b, T)
155+
_get_coefficients(w::WeeksErr{T}) where T = _get_coefficients_and_params(w.func, w.Nterms, w.sigma, w.b, T)
153156
_set_coefficients(w::WeeksErr) = (w.coefficients, w.sa1, w.sa2) = _get_coefficients_and_params(w)
154157

158+
# FIXME: magic numbers here
155159
function optimize(w::WeeksErr, t)
156160
(w.sigma, w.b) = _optimize_sigma_and_b(w.func, t, w.Nterms, 0.0, 30, 30)
157161
_set_coefficients(w)
158162
return w
159163
end
160164

161165
"""
162-
w::WeeksErr = WeeksErr(func::Function, Nterms::Integer=64, sigma=1.0, b=1.0)
166+
w::WeeksErr = WeeksErr(func::Function, Nterms::Integer=64, sigma=1.0, b=1.0; datatype=Float64)
163167
164-
return `w`, which estimates the inverse Laplace transform of `func` via the Weeks algorithm.
168+
Return `w`, which estimates the inverse Laplace transform of `func` via the Weeks algorithm.
165169
`w(t)` returns a tuple containing the inverse transform at `t` and an error estimate. The accuracy of the
166-
inversion depends on the choice of `sigma` and `b`.
170+
inversion depends on the choice of `sigma` and `b`. See the documentation for `Weeks` for a
171+
description of the parameter `datatype`.
167172
168173
# Example
169174
@@ -186,9 +191,10 @@ julia> ft(pi/2)[1] - cospi(1/2) # cospi is more accurate
186191
0.0
187192
```
188193
"""
189-
function WeeksErr(func::Function, Nterms::Integer=weeks_default_num_terms, sigma=1.0, b=1.0)
190-
params = _get_coefficients_and_params(func, Nterms, sigma, b)
191-
return WeeksErr(func,Nterms,sigma,b,params...)
194+
function WeeksErr(func::Function, Nterms::Integer=weeks_default_num_terms, sigma=1.0, b=1.0; datatype=Float64)
195+
outdatatype = datatype == Complex ? Complex{Float64} : datatype # allow `Complex` as abbrev for Complex{Float64}
196+
params = _get_coefficients_and_params(func, Nterms, sigma, b, outdatatype)
197+
return WeeksErr{outdatatype}(func, Nterms, sigma, b, params...)
192198
end
193199

194200
function eval_ilt(w::WeeksErr, t)
@@ -211,6 +217,9 @@ end
211217

212218
##### internal functions
213219

220+
_wcoeff(F, N, sig, b, ::Type{T}) where T <: Real = real(_wcoeff(F, N, sig, b))
221+
_wcoeff(F, N, sig, b, ::Type{T}) where T <: Complex = _wcoeff(F, N, sig, b)
222+
214223
function _wcoeff(F, N, sig, b)
215224
n = -N:N-1 # FIXME: remove 1 and test
216225
h = pi / N # FIXME: what data type ?

test/weeks_test.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
@test typeof(Weeks(s -> 1 / s)) == Weeks
2-
@test typeof(WeeksErr(s -> 1 / s)) == WeeksErr
1+
@test isa(Weeks(s -> 1 / s), Weeks)
2+
@test isa(WeeksErr(s -> 1 / s), WeeksErr)
33

44
fl = Weeks(s -> 1 / s^2)
55

@@ -35,9 +35,27 @@ e2 = abs(fle(10.0)[1] - cos(10.0))
3535

3636
fle = WeeksErr(s -> s/(1+s^2), 64)
3737
c1 = copy(fle.coefficients)
38-
@test string(fle) == "WeeksErr(Nterms=64,sigma=1.0,b=1.0)"
38+
@test string(fle) == "WeeksErr{Float64}(Nterms=64,sigma=1.0,b=1.0)"
3939
setparameters(fle,2.0,2.0,80)
40-
@test string(fle) == "WeeksErr(Nterms=80,sigma=2.0,b=2.0)"
40+
@test string(fle) == "WeeksErr{Float64}(Nterms=80,sigma=2.0,b=2.0)"
4141

4242
# Check that Laguerre coefficients have been recomputed
4343
@test c1 != fle.coefficients
44+
45+
### Complex
46+
47+
function Fcomplex(s)
48+
# Laplace domain
49+
α = complex(-0.3, 6.0)
50+
return 1 / (s - α)
51+
end
52+
53+
function fcomplex(t)
54+
# Time domain
55+
α = complex(-0.3, 6.0)
56+
return exp* t)
57+
end
58+
59+
let Fc = Weeks(Fcomplex, 1024, datatype=Complex), trange = range(0.0, stop=30.0, length=1000)
60+
@test isapprox(Fc.(trange), fcomplex.(trange), atol=0.001)
61+
end

0 commit comments

Comments
 (0)