Skip to content

Commit 935313b

Browse files
committed
cleaner distinction of Sarsa, ExpSarsa and QLearning
1 parent bcd9bea commit 935313b

File tree

1 file changed

+26
-21
lines changed

1 file changed

+26
-21
lines changed

src/learner/tdlearning.jl

Lines changed: 26 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -10,11 +10,16 @@
1010
params::Array{Float64, 2} = zeros(na, ns) + initvalue
1111
tracekind = DataType = ReplacingTraces
1212
traces::T = λ == 0 || tracekind == NoTraces ? NoTraces() : tracekind(ns, na, λ, γ)
13-
endvaluepolicy::Tp = :Sarsa
13+
endvaluepolicy::Tp = SarsaEndPolicy()
14+
end
15+
struct SarsaEndPolicy end
16+
struct QLearningEndPolicy end
17+
struct ExpectedSarsaEndPolicy{Tp}
18+
policy::Tp
1419
end
1520
Sarsa(; kargs...) = TDLearner(; kargs...)
16-
QLearning(; kargs...) = TDLearner(; endvaluepolicy = :QLearning, kargs...)
17-
ExpectedSarsa(; kargs...) = TDLearner(; endvaluepolicy = VeryOptimisticEpsilonGreedyPolicy(.1), kargs...)
21+
QLearning(; kargs...) = TDLearner(; endvaluepolicy = QLearningEndPolicy(), kargs...)
22+
ExpectedSarsa(; kargs...) = TDLearner(; endvaluepolicy = ExpectedSarsaEndPolicy(VeryOptimisticEpsilonGreedyPolicy(.1)), kargs...)
1823
export Sarsa, QLearning, ExpectedSarsa
1924

2025
@inline function selectaction(learner::Union{TDLearner, AbstractPolicyGradient},
@@ -31,35 +36,35 @@ reconstructwithparams(learner::TDLearner, w) = reconstruct(learner, params = w)
3136
@inline getvaluecheckinf(learner, a, s::AbstractArray) = getvalue(learner.params, a, s)
3237
@inline checkinf(learner, value) = (value == Inf64 ? learner.unseenvalue : value)
3338

34-
@inline function futurevalue(learner, buffer)
39+
@inline function futurevalue(::QLearningEndPolicy, learner, buffer)
40+
checkinf(learner, maximumbelowInf(getvalue(learner.params, buffer.states[end])))
41+
end
42+
@inline function futurevalue(::SarsaEndPolicy, learner, buffer)
43+
getvaluecheckinf(learner, buffer.actions[end], buffer.states[end])
44+
end
45+
@inline function futurevalue(p::ExpectedSarsaEndPolicy, learner, buffer)
3546
a = buffer.actions[end]
3647
s = buffer.states[end]
37-
if learner.endvaluepolicy == :QLearning
38-
checkinf(learner, maximumbelowInf(getvalue(learner.params, s)))
39-
elseif learner.endvaluepolicy == :Sarsa
40-
getvaluecheckinf(learner, a, s)
41-
else
42-
actionprobabilites = getactionprobabilities(learner.endvaluepolicy,
43-
getvalue(learner.params, s))
44-
m = 0.
45-
for (a, w) in enumerate(actionprobabilites)
46-
if w != 0.
47-
m += w * getvaluecheckinf(learner, a, s)
48-
end
48+
actionprobabilites = getactionprobabilities(learner.endvaluepolicy.policy,
49+
getvalue(learner.params, s))
50+
m = 0.
51+
for (a, w) in enumerate(actionprobabilites)
52+
if w != 0.
53+
m += w * getvaluecheckinf(learner, a, s)
4954
end
50-
m
5155
end
56+
m
5257
end
5358

5459
@inline function discountedrewards(rewards, done, γ)
5560
gammaeff = 1.
5661
discr = 0.
57-
for (r, done) in zip(rewards, done)
62+
for (r, d) in zip(rewards, done)
5863
discr += gammaeff * r
59-
done && return [discr; 0.]
64+
d && return discr, 0.
6065
gammaeff *= γ
6166
end
62-
[discr; gammaeff]
67+
discr, gammaeff
6368
end
6469
@inline function tderror(rewards, done, γ, startvalue, endvalue)
6570
discr, gammaeff = discountedrewards(rewards, done, γ)
@@ -69,7 +74,7 @@ end
6974
function tderror(learner, buffer)
7075
tderror(buffer.rewards, buffer.done, learner.γ,
7176
getvaluecheckinf(learner, buffer.actions[1], buffer.states[1]),
72-
futurevalue(learner, buffer))
77+
futurevalue(learner.endvaluepolicy, learner, buffer))
7378
end
7479

7580
# update params

0 commit comments

Comments (0)