     params::Array{Float64, 2} = zeros(na, ns) + initvalue
     tracekind::DataType = ReplacingTraces
     traces::T = λ == 0 || tracekind == NoTraces ? NoTraces() : tracekind(ns, na, λ, γ)
-    endvaluepolicy::Tp = :Sarsa
+    endvaluepolicy::Tp = SarsaEndPolicy()
+end
+struct SarsaEndPolicy end
+struct QLearningEndPolicy end
+struct ExpectedSarsaEndPolicy{Tp}
+    policy::Tp
 end
 Sarsa(; kargs...) = TDLearner(; kargs...)
-QLearning(; kargs...) = TDLearner(; endvaluepolicy = :QLearning, kargs...)
-ExpectedSarsa(; kargs...) = TDLearner(; endvaluepolicy = VeryOptimisticEpsilonGreedyPolicy(.1), kargs...)
+QLearning(; kargs...) = TDLearner(; endvaluepolicy = QLearningEndPolicy(), kargs...)
+ExpectedSarsa(; kargs...) = TDLearner(; endvaluepolicy = ExpectedSarsaEndPolicy(VeryOptimisticEpsilonGreedyPolicy(.1)), kargs...)
 export Sarsa, QLearning, ExpectedSarsa
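# A minimal usage sketch of the constructors above (illustration, not part of the diff).
# It assumes the learner accepts keyword arguments such as `ns` and `na`, matching the
# field defaults shown at the top; the exact keyword set is defined in TDLearner and is
# an assumption here:
#     q = QLearning(ns = 10, na = 4)
#     q.endvaluepolicy isa QLearningEndPolicy                        # true
#     e = ExpectedSarsa(ns = 10, na = 4)
#     e.endvaluepolicy.policy isa VeryOptimisticEpsilonGreedyPolicy  # true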
 
 @inline function selectaction(learner::Union{TDLearner, AbstractPolicyGradient},
@@ -31,35 +36,35 @@ reconstructwithparams(learner::TDLearner, w) = reconstruct(learner, params = w)
 @inline getvaluecheckinf(learner, a, s::AbstractArray) = getvalue(learner.params, a, s)
 @inline checkinf(learner, value) = (value == Inf64 ? learner.unseenvalue : value)
 
-@inline function futurevalue(learner, buffer)
+@inline function futurevalue(::QLearningEndPolicy, learner, buffer)
+    checkinf(learner, maximumbelowInf(getvalue(learner.params, buffer.states[end])))
+end
+@inline function futurevalue(::SarsaEndPolicy, learner, buffer)
+    getvaluecheckinf(learner, buffer.actions[end], buffer.states[end])
+end
+@inline function futurevalue(p::ExpectedSarsaEndPolicy, learner, buffer)
     a = buffer.actions[end]
     s = buffer.states[end]
-    if learner.endvaluepolicy == :QLearning
-        checkinf(learner, maximumbelowInf(getvalue(learner.params, s)))
-    elseif learner.endvaluepolicy == :Sarsa
-        getvaluecheckinf(learner, a, s)
-    else
-        actionprobabilites = getactionprobabilities(learner.endvaluepolicy,
-                                                    getvalue(learner.params, s))
-        m = 0.
-        for (a, w) in enumerate(actionprobabilites)
-            if w != 0.
-                m += w * getvaluecheckinf(learner, a, s)
-            end
+    actionprobabilites = getactionprobabilities(learner.endvaluepolicy.policy,
+                                                getvalue(learner.params, s))
+    m = 0.
+    for (a, w) in enumerate(actionprobabilites)
+        if w != 0.
+            m += w * getvaluecheckinf(learner, a, s)
         end
-        m
     end
+    m
 end
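# The ExpectedSarsaEndPolicy method above bootstraps with the expected Q-value of the
# last state under the wrapped policy: getactionprobabilities supplies the action
# weights, and m accumulates the probability-weighted Q-values (zero-weight actions are
# skipped since they contribute nothing), i.e. an estimate of the expectation of
# Q(s_end, a) with a drawn from that policy.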
 
 @inline function discountedrewards(rewards, done, γ)
     gammaeff = 1.
     discr = 0.
-    for (r, done) in zip(rewards, done)
+    for (r, d) in zip(rewards, done)
         discr += gammaeff * r
-        done && return [discr; 0.]
+        d && return discr, 0.
         gammaeff *= γ
     end
-    [discr; gammaeff]
+    discr, gammaeff
 end
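# A worked example of the (discounted return, remaining discount weight) tuple returned
# above, with γ = 0.9 (values follow directly from the loop):
#     discountedrewards([1., 2., 3.], [false, false, false], 0.9)  # ≈ (5.23, 0.729)
#     discountedrewards([1., 2.], [false, true], 0.9)              # ≈ (2.8, 0.); episode done, so no bootstrap weight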
 @inline function tderror(rewards, done, γ, startvalue, endvalue)
     discr, gammaeff = discountedrewards(rewards, done, γ)
 function tderror(learner, buffer)
     tderror(buffer.rewards, buffer.done, learner.γ,
             getvaluecheckinf(learner, buffer.actions[1], buffer.states[1]),
-            futurevalue(learner, buffer))
+            futurevalue(learner.endvaluepolicy, learner, buffer))
 end
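# Note on the call above: tderror(learner, buffer) now passes learner.endvaluepolicy
# explicitly so that the matching futurevalue method is selected by dispatch. The
# numeric TD error is computed by tderror(rewards, done, γ, startvalue, endvalue); its
# body is only partially shown in this diff, but given discountedrewards it is
# presumably the standard n-step error discr + gammaeff * endvalue - startvalue, with
# gammaeff == 0 after a terminal step.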
 
 # update params