Skip to content

Commit aa69028

Browse files
committed
speed up traces
1 parent c0d2293 commit aa69028

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

src/learner/tdlearning.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
initvalue::Float64 = 0.
99
unseenvalue::Float64 = initvalue == Inf64 ? 0. : initvalue
1010
params::Array{Float64, 2} = zeros(na, ns) + initvalue
11-
tracekind = DataType = ReplacingTraces
12-
traces::T = λ == 0 || tracekind == NoTraces ? NoTraces() : tracekind(ns, na, λ, γ)
11+
tracekind = DataType = λ == 0 ? NoTraces : ReplacingTraces
12+
traces::T = tracekind == NoTraces ? NoTraces() : tracekind(ns, na, λ, γ)
1313
endvaluepolicy::Tp = SarsaEndPolicy()
1414
end
1515
struct SarsaEndPolicy end

src/traces.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,25 @@ end
9191

9292
discounttraces!(t) = discounttraces!(t.trace, t.γλ, t.minimaltracevalue)
9393
@inline function discounttraces!(trace::SparseMatrixCSC, γλ, minimaltracevalue)
94-
BLAS.scale!(γλ, trace.nzval)
95-
if rand() < .01
96-
clamp!(trace.nzval, minimaltracevalue, Inf)
94+
x = trace.nzval
95+
@simd for i in 1:length(x)
96+
@inbounds if x[i] <= minimaltracevalue
97+
x[i] = 0.
98+
else
99+
x[i] *= γλ
100+
end
101+
end
102+
if rand() < .005
103+
dropzeros!(trace)
97104
end
98105
end
99106
@inline discounttraces!(trace, γλ, minimaltracevalue) = BLAS.scale!(γλ, trace)
100-
resettraces!(traces) = BLAS.scale!(0., traces.trace)
107+
@inline resettraces!(traces) = resettrace!(traces.trace)
108+
@inline resettrace!(trace) = BLAS.scale!(0., trace)
109+
@inline function resettrace!(trace::SparseMatrixCSC)
110+
BLAS.scale!(0., trace.nzval)
111+
dropzeros!(trace)
112+
end
101113

102114
function updatetraceandparams!(traces, params, factor)
103115
updatetraceandparams!(traces.trace, params, factor)

0 commit comments

Comments
 (0)