Skip to content

Commit e6dffbd

Browse files
committed
Add complex VML functions
1 parent 3692c75 commit e6dffbd

File tree

7 files changed

+244
-174
lines changed

7 files changed

+244
-174
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ regarding these options is available on
3333

3434
![VML Performance Comparison](/benchmark/performance.png)
3535

36+
![VML Complex Performance Comparison](/benchmark/performance_complex.png)
37+
3638
Tests were performed on an Intel(R) Core(TM) i7-3930K CPU. Error bars
3739
are 95% confidence intervals based on 25 repetitions of each test with
3840
a 1,000,000 element vector. The dashed line indicates equivalent

benchmark/benchmark.jl

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,29 @@ using Distributions, PyCall, PyPlot
22
@pyimport matplotlib.gridspec as gridspec
33

44
include(joinpath(dirname(dirname(@__FILE__)), "test", "common.jl"))
5+
complex = !isempty(ARGS) && ARGS[1] == "complex"
56

6-
function bench(fns, input, nrep)
7+
# Time every function in `fns` on its matching input, once per element type.
#
# fns   - indexable collection of callables
# input - collection indexed by element type; `input[t][i]` is the tuple of
#         arguments splatted into `fns[i]`
#
# The element types are read from the global `types`. The result is keyed by
# element type; each value is a vector whose i-th entry holds the
# per-repetition timings (seconds, as reported by @elapsed) of `fns[i]`.
function bench(fns, input)
    [t=>begin
        times = Array(Vector{Float64}, length(fns))
        for ifn = 1:length(fns)
            fn = fns[ifn]
            inp = input[t][ifn]
            # Untimed first call (presumably so compilation is not measured).
            fn(inp...)
            gc()
            # Pilot run: pick enough repetitions for roughly 2 seconds of
            # work, but never fewer than 3. The trailing gc() is deliberately
            # part of the measured time, as in the timed loop's overhead.
            pilot = @elapsed begin
                gc_disable()
                try
                    fn(inp...)
                finally
                    # Re-enable the GC even if fn throws.
                    gc_enable()
                end
                gc()
            end
            nrep = max(iceil(2/pilot), 3)
            println("Running $nrep reps of $fn($t)")
            @time times[ifn] = [begin
                gc()
                gc_disable()
                # try/finally so a throwing fn cannot leave the GC disabled.
                try
                    @elapsed fn(inp...)
                finally
                    gc_enable()
                end
            end for i = 1:nrep]
        end
        times
    end for t in types]
end
2229

2330
function ratioci(y, x, alpha=0.05)
@@ -33,34 +40,33 @@ end
3340

3441
# First generate some random data and test functions in Base on it
3542
const NVALS = 1_000_000
43+
base_unary = complex ? base_unary_complex : base_unary_real
44+
base_binary = complex ? base_binary_complex : base_binary_real
45+
types = complex ? (Complex64, Complex128) : (Float32, Float64)
3646
input = [t=>[[(randindomain(t, NVALS, domain),) for (fn, domain) in base_unary];
3747
[(randindomain(t, NVALS, domain1), randindomain(t, NVALS, domain2))
3848
for (fn, domain1, domain2) in base_binary];
3949
(randindomain(t, NVALS, (0, 100)), randindomain(t, 1, (-1, 20))[1])]
40-
for t in (Float32, Float64)]
41-
fns = [[x[1] for x in base_unary]; [x[1] for x in base_binary]; .^]
50+
for t in types]
51+
fns = [[x[1] for x in base_unary]; [x[1] for x in base_binary]; (complex ? [] : .^)]
4252

43-
bench(fns, input, 1)
44-
builtin = bench(fns, input, 25)
53+
builtin = bench(fns, input)
4554

4655
# Now with VML
4756
using VML
48-
#vml_set_accuracy(VML_LA)
4957

50-
bench(fns, input, 1)
51-
vml = bench(fns, input, 25)
58+
vml = bench(fns, input)
5259

5360
# Print ratio
5461
clf()
55-
types = (Float32, Float64)
5662
colors = ["r", "y"]
5763
for itype = 1:length(types)
5864
builtint = builtin[types[itype]]
5965
vmlt = vml[types[itype]]
60-
μ = vec(mean(builtint, 1)./mean(vmlt, 1))
66+
μ = vec(map(mean, builtint)./map(mean, vmlt))
6167
ci = zeros(Float64, 2, length(fns))
62-
for ifn = 1:size(builtint, 2)
63-
lower, upper = ratioci(builtint[:, ifn], vmlt[:, ifn])
68+
for ifn = 1:length(builtint)
69+
lower, upper = ratioci(builtint[ifn], vmlt[ifn])
6470
ci[1, ifn] = μ[ifn] - lower
6571
ci[2, ifn] = upper - μ[ifn]
6672
end
@@ -69,11 +75,13 @@ end
6975
ax = gca()
7076
ax[:set_xlim](0, length(fns)+1)
7177
fname = [string(fn.env.name) for fn in fns]
72-
fname[end-1] = "A.^B"
73-
fname[end] = "A.^b"
78+
if !complex
79+
fname[end-1] = "A.^B"
80+
fname[end] = "A.^b"
81+
end
7482
xticks(1:length(fns)+1, fname, rotation=70, fontsize=10)
7583
title("VML Performance")
7684
ylabel("Relative Speed (Base/VML)")
77-
legend(("Float32", "Float64"))
85+
legend([string(x) for x in types])
7886
ax[:axhline](1; color="black", linestyle="--")
79-
savefig("performance.png")
87+
savefig("performance$(complex ? "_complex" : "").png")

benchmark/performance.png

-5.34 KB
Loading

benchmark/performance_complex.png

29.8 KB
Loading

src/VML.jl

Lines changed: 131 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -34,103 +34,125 @@ function vml_check_error()
3434
end
3535
end
3636

37-
const unary_ops = [(:(Base.acos), :acos!, :Acos),
38-
(:(Base.asin), :asin!, :Asin),
39-
(:(Base.atan), :atan!, :Atan),
40-
(:(Base.cos), :cos!, :Cos),
41-
(:(Base.sin), :sin!, :Sin),
42-
(:(Base.tan), :tan!, :Tan),
43-
(:(Base.acosh), :acosh!, :Acosh),
44-
(:(Base.asinh), :asinh!, :Asinh),
45-
(:(Base.atanh), :atanh!, :Atanh),
46-
(:(Base.cosh), :cosh!, :Cosh),
47-
(:(Base.sinh), :sinh!, :Sinh),
48-
(:(Base.tanh), :tanh!, :Tanh),
49-
(:(Base.cbrt), :cbrt!, :Cbrt),
50-
(:(Base.sqrt), :sqrt!, :Sqrt),
51-
(:(Base.exp), :exp!, :Exp),
52-
(:(Base.expm1), :expm1!, :Expm1),
53-
(:(Base.log), :log!, :Ln),
54-
(:(Base.log10), :log10!, :Log10),
55-
(:(Base.log1p), :log1p, :Log1p),
56-
(:(Base.abs), :abs!, :Abs),
57-
(:(Base.abs2), :abs2!, :Sqr),
58-
(:(Base.ceil), :ceil!, :Ceil),
59-
(:(Base.floor), :floor!, :Floor),
60-
(:(Base.round), :round!, :Round),
61-
(:(Base.trunc), :trunc!, :Trunc),
62-
(:(Base.erf), :erf!, :Erf),
63-
(:(Base.erfc), :erfc!, :Erfc),
64-
(:(Base.erfinv), :erfinv!, :ErfInv),
65-
(:(Base.erfcinv), :erfcinv!, :ErfcInv),
66-
(:(Base.lgamma), :lgamma!, :LGamma),
67-
(:(Base.gamma), :gamma!, :TGamma),
68-
# Not in Base
69-
(:inv_cbrt, :inv_cbrt!, :InvCbrt),
70-
(:inv_sqrt, :inv_sqrt!, :InvSqrt),
71-
(:pow2o3, :pow2o3!, :Pow2o3),
72-
(:pow3o2, :pow3o2!, :Pow3o2)]
73-
74-
const binary_vector_ops = [(:(Base.atan2), :atan2!, :Atan2, false),
75-
(:(Base.hypot), :hypot!, :Hypot, false),
76-
(:(Base.(:.^)), :pow!, :Pow, true),
77-
(:(Base.(:./)), :divide!, :Div, true)]
78-
79-
for (prefix, t) in ((:_vmls, :Float32), (:_vmld, :Float64))
80-
# Unary
81-
for (jlname, jlname!, mklname) in unary_ops
82-
mklfn = Base.Meta.quot(symbol("$prefix$mklname"))
83-
exports = Symbol[]
84-
isa(jlname, Expr) || push!(exports, jlname)
85-
isa(jlname!, Expr) || push!(exports, jlname!)
86-
@eval begin
87-
$(isempty(exports) ? nothing : Expr(:export, exports...))
88-
function $(jlname!){N}(out::Array{$t,N}, A::Array{$t,N})
89-
size(out) == size(A) || throw(DimensionMismatch())
90-
ccall(($mklfn, lib), Void, (Int, Ptr{$t}, Ptr{$t}), length(A), A, out)
91-
vml_check_error()
92-
out
93-
end
94-
function $(jlname!)(A::Array{$t})
95-
ccall(($mklfn, lib), Void, (Int, Ptr{$t}, Ptr{$t}), length(A), A, A)
96-
vml_check_error()
97-
A
98-
end
99-
function $(jlname)(A::Array{$t})
100-
out = similar(A)
101-
ccall(($mklfn, lib), Void, (Int, Ptr{$t}, Ptr{$t}), length(A), A, out)
102-
vml_check_error()
103-
out
37+
# Return the MKL VML symbol prefix used for element type `t`:
# "_vmls" (Float32), "_vmld" (Float64), "_vmlc" (Complex{Float32}) or
# "_vmlz" (Complex{Float64}). Any other type raises an error.
function vml_prefix(t::DataType)
    t == Float32 && return "_vmls"
    t == Float64 && return "_vmld"
    t == Complex{Float32} && return "_vmlc"
    t == Complex{Float64} && return "_vmlz"
    error("unknown type $t")
end
49+
50+
# Generate the VML-backed methods for one unary elementwise operation.
#
# tin     - element type of the input array; selects the MKL symbol prefix
#           through vml_prefix
# tout    - element type of the output array
# jlname  - name of the allocating function (a Symbol, or an Expr such as
#           :(Base.exp))
# jlname! - name of the function that writes into a caller-supplied array
# mklname - MKL routine name without its type prefix, e.g. :Exp
#
# Three methods are emitted: `jlname!(out, A)`, `jlname!(A)` (only when
# tin == tout, since the result overwrites the input) and `jlname(A)`.
# Names passed as anything other than an Expr are exported.
function def_unary_op(tin, tout, jlname, jlname!, mklname)
    mklfn = Base.Meta.quot(symbol("$(vml_prefix(tin))$mklname"))

    exports = Symbol[]
    for name in (jlname, jlname!)
        isa(name, Expr) || push!(exports, name)
    end
    exportexpr = isempty(exports) ? nothing : Expr(:export, exports...)

    # Overwriting the input is only possible when the element types agree.
    inplace = nothing
    if tin == tout
        inplace = quote
            function $(jlname!)(A::Array{$tin})
                ccall(($mklfn, lib), Void, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, A)
                vml_check_error()
                A
            end
        end
    end

    @eval begin
        $exportexpr
        function $(jlname!){N}(out::Array{$tout,N}, A::Array{$tin,N})
            size(out) == size(A) || throw(DimensionMismatch())
            ccall(($mklfn, lib), Void, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, out)
            vml_check_error()
            out
        end
        $inplace
        function $(jlname)(A::Array{$tin})
            out = similar(A, $tout)
            ccall(($mklfn, lib), Void, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, out)
            vml_check_error()
            out
        end
    end
end
10780

108-
# Binary, two vectors
109-
for (jlname, jlname!, mklname, broadcast) in binary_vector_ops
110-
mklfn = Base.Meta.quot(symbol("$prefix$mklname"))
111-
exports = Symbol[]
112-
isa(jlname, Expr) || push!(exports, jlname)
113-
isa(jlname!, Expr) || push!(exports, jlname!)
114-
@eval begin
115-
$(isempty(exports) ? nothing : Expr(:export, exports...))
116-
function $(jlname!){N}(out::Array{$t,N}, A::Array{$t,N}, B::Array{$t,N})
117-
size(out) == size(A) == size(B) || $(broadcast ? :(return broadcast!($jlname, out, A, B)) : :(throw(DimensionMismatch())))
118-
ccall(($mklfn, lib), Void, (Int, Ptr{$t}, Ptr{$t}, Ptr{$t}), length(A), A, B, out)
119-
vml_check_error()
120-
out
121-
end
122-
function $(jlname){N}(A::Array{$t,N}, B::Array{$t,N})
123-
size(A) == size(B) || $(broadcast ? :(return broadcast($jlname, A, B)) : :(throw(DimensionMismatch())))
124-
out = similar(A)
125-
ccall(($mklfn, lib), Void, (Int, Ptr{$t}, Ptr{$t}, Ptr{$t}), length(A), A, B, out)
126-
vml_check_error()
127-
out
128-
end
81+
# Generate the VML-backed methods for one binary elementwise operation.
#
# tin       - element type of both input arrays; selects the MKL symbol
#             prefix through vml_prefix
# tout      - element type of the output array
# jlname    - name of the allocating function (a Symbol, or an Expr such as
#             :(Base.hypot))
# jlname!   - name of the function that writes into a caller-supplied `out`
# mklname   - MKL routine name without its type prefix, e.g. :Pow
# broadcast - when true, a size mismatch falls back to broadcast/broadcast!
#             with `jlname`; when false it throws DimensionMismatch
#
# Names passed as anything other than an Expr are exported.
function def_binary_op(tin, tout, jlname, jlname!, mklname, broadcast)
    mklfn = Base.Meta.quot(symbol("$(vml_prefix(tin))$mklname"))
    exports = Symbol[]
    isa(jlname, Expr) || push!(exports, jlname)
    isa(jlname!, Expr) || push!(exports, jlname!)

    # What each generated method does when the argument sizes disagree.
    onmismatch_inplace = broadcast ? :(return broadcast!($jlname, out, A, B)) :
                                     :(throw(DimensionMismatch()))
    onmismatch = broadcast ? :(return broadcast($jlname, A, B)) :
                             :(throw(DimensionMismatch()))

    @eval begin
        $(isempty(exports) ? nothing : Expr(:export, exports...))
        function $(jlname!){N}(out::Array{$tout,N}, A::Array{$tin,N}, B::Array{$tin,N})
            size(out) == size(A) == size(B) || $onmismatch_inplace
            ccall(($mklfn, lib), Void, (Int, Ptr{$tin}, Ptr{$tin}, Ptr{$tout}), length(A), A, B, out)
            vml_check_error()
            out
        end
        # Both inputs have element type tin; the result is allocated with
        # element type tout so the ccall signature matches when they differ.
        function $(jlname){N}(A::Array{$tin,N}, B::Array{$tin,N})
            size(A) == size(B) || $onmismatch
            out = similar(A, $tout)
            ccall(($mklfn, lib), Void, (Int, Ptr{$tin}, Ptr{$tin}, Ptr{$tout}), length(A), A, B, out)
            vml_check_error()
            out
        end
    end
end
103+
104+
# Operations registered for both real and complex element types. Each entry
# is (Julia function, in-place function, MKL routine name without prefix).
for t in (Float32, Float64, Complex64, Complex128)
    # Unary, real or complex
    for (jlname, jlname!, mklname) in ((:(Base.acos), :acos!, :Acos),
                                       (:(Base.asin), :asin!, :Asin),
                                       (:(Base.acosh), :acosh!, :Acosh),
                                       (:(Base.asinh), :asinh!, :Asinh),
                                       (:(Base.sqrt), :sqrt!, :Sqrt),
                                       (:(Base.exp), :exp!, :Exp),
                                       (:(Base.log), :log!, :Ln))
        def_unary_op(t, t, jlname, jlname!, mklname)
    end

    # Binary, real or complex; size mismatches fall back to broadcasting
    for (jlname, jlname!, mklname) in ((:(Base.(:.^)), :pow!, :Pow),
                                       (:(Base.(:./)), :divide!, :Div))
        def_binary_op(t, t, jlname, jlname!, mklname, true)
    end
end
118+
119+
for t in (Float32, Float64)
120+
# Unary, real-only
121+
def_unary_op(t, t, :(Base.cbrt), :cbrt!, :Cbrt)
122+
def_unary_op(t, t, :(Base.expm1), :expm1!, :Expm1)
123+
# The in-place method is named log1p!, matching every other op; Base.log1p
# itself stays non-mutating.
def_unary_op(t, t, :(Base.log1p), :log1p!, :Log1p)
124+
def_unary_op(t, t, :(Base.abs), :abs!, :Abs)
125+
def_unary_op(t, t, :(Base.abs2), :abs2!, :Sqr)
126+
def_unary_op(t, t, :(Base.ceil), :ceil!, :Ceil)
127+
def_unary_op(t, t, :(Base.floor), :floor!, :Floor)
128+
def_unary_op(t, t, :(Base.round), :round!, :Round)
129+
def_unary_op(t, t, :(Base.trunc), :trunc!, :Trunc)
130+
def_unary_op(t, t, :(Base.erf), :erf!, :Erf)
131+
def_unary_op(t, t, :(Base.erfc), :erfc!, :Erfc)
132+
def_unary_op(t, t, :(Base.erfinv), :erfinv!, :ErfInv)
133+
def_unary_op(t, t, :(Base.erfcinv), :erfcinv!, :ErfcInv)
134+
def_unary_op(t, t, :(Base.lgamma), :lgamma!, :LGamma)
135+
def_unary_op(t, t, :(Base.gamma), :gamma!, :TGamma)
136+
# Not in Base
137+
def_unary_op(t, t, :inv_cbrt, :inv_cbrt!, :InvCbrt)
138+
def_unary_op(t, t, :inv_sqrt, :inv_sqrt!, :InvSqrt)
139+
def_unary_op(t, t, :pow2o3, :pow2o3!, :Pow2o3)
140+
def_unary_op(t, t, :pow3o2, :pow3o2!, :Pow3o2)
141+
142+
# Enabled only for Real. MKL guarantees higher accuracy, but at a
143+
# substantial performance cost.
144+
def_unary_op(t, t, :(Base.atan), :atan!, :Atan)
145+
def_unary_op(t, t, :(Base.cos), :cos!, :Cos)
146+
def_unary_op(t, t, :(Base.sin), :sin!, :Sin)
147+
def_unary_op(t, t, :(Base.tan), :tan!, :Tan)
148+
def_unary_op(t, t, :(Base.atanh), :atanh!, :Atanh)
149+
def_unary_op(t, t, :(Base.cosh), :cosh!, :Cosh)
150+
def_unary_op(t, t, :(Base.sinh), :sinh!, :Sinh)
151+
def_unary_op(t, t, :(Base.tanh), :tanh!, :Tanh)
152+
def_unary_op(t, t, :(Base.log10), :log10!, :Log10)
153+
154+
# .^ to scalar power
155+
mklfn = Base.Meta.quot(symbol("$(vml_prefix(t))Powx"))
134156
@eval begin
135157
export pow!
136158
function pow!{N}(out::Array{$t,N}, A::Array{$t,N}, b::$t)
@@ -146,6 +168,25 @@ for (prefix, t) in ((:_vmls, :Float32), (:_vmld, :Float64))
146168
out
147169
end
148170
end
171+
172+
# Binary, real-only
173+
def_binary_op(t, t, :(Base.atan2), :atan2!, :Atan2, false)
174+
def_binary_op(t, t, :(Base.hypot), :hypot!, :Hypot, false)
175+
176+
# Unary, complex-only
177+
def_unary_op(t, Complex{t}, :(Base.cis), :cis!, :CIS)
178+
# def_unary_op(Complex{t}, Complex{t}, :(Base.conj), :conj!, :Conj)
179+
def_unary_op(Complex{t}, t, :(Base.abs), :abs!, :Abs)
180+
def_unary_op(Complex{t}, t, :(Base.angle), :angle!, :Arg)
181+
182+
# Binary, complex-only. These are more accurate but performance is
183+
# either equivalent to Base or slower.
184+
# def_binary_op(Complex{t}, Complex{t}, :(Base.(:+)), :add!, :Add, false)
185+
# def_binary_op(Complex{t}, Complex{t}, :(Base.(:.+)), :add!, :Add, true)
186+
# def_binary_op(Complex{t}, Complex{t}, :(Base.(:.*)), :multiply!, :Mul, true)
187+
# def_binary_op(Complex{t}, Complex{t}, :(Base.(:-)), :subtract!, :Sub, false)
188+
# def_binary_op(Complex{t}, Complex{t}, :(Base.(:.-)), :subtract!, :Sub, true)
189+
# def_binary_op(Complex{t}, Complex{t}, :multiply_conj, :multiply_conj!, :Mul, false)
149190
end
150191

151192
export VML_LA, VML_HA, VML_EP, vml_set_accuracy, vml_get_accuracy

0 commit comments

Comments
 (0)