Commit bcd9bea

speed up epsilon greedy policies
1 parent dfe12d5 commit bcd9bea

File tree

2 files changed: +50 -41 lines


src/epsilongreedypolicies.jl

Lines changed: 36 additions & 33 deletions
@@ -5,25 +5,6 @@ abstract type AbstractEpsilonGreedyPolicy end
         VeryOptimisticEpsilonGreedyPolicy,
         PesimisticEpsilonGreedyPolicy)
 
-
-for (typ, max, rel) in ((OptimisticEpsilonGreedyPolicy, maximumbelowInf, :(>=)),
-                        (VeryOptimisticEpsilonGreedyPolicy, maximum, :(==)),
-                        (PesimisticEpsilonGreedyPolicy, maximumbelowInf, :(==)))
-    @eval function getgreedystates(policy::$typ, values)
-        a = Int64[]
-        vmax = $max(values)
-        if isnan(vmax)
-            error("NaN encountered in getgreedystates: $values")
-        end
-        for (i, v) in enumerate(values)
-            if ($rel)(v, vmax)
-                push!(a, i)
-            end
-        end
-        a
-    end
-end
-
 const EpsilonGreedyPolicy = VeryOptimisticEpsilonGreedyPolicy
 export EpsilonGreedyPolicy
 
@@ -63,21 +44,43 @@ where never chosen before.
 """ PesimisticEpsilonGreedyPolicy
 
 
-function selectaction(policy::AbstractEpsilonGreedyPolicy, values)
-    if rand() < policy.ϵ
-        rand(1:length(values))
-    else
-        rand(getgreedystates(policy, values))
+for (typ, max, rel) in ((OptimisticEpsilonGreedyPolicy, maximumbelowInf, :(>=)),
+                        (VeryOptimisticEpsilonGreedyPolicy, maximum, :(==)),
+                        (PesimisticEpsilonGreedyPolicy, maximumbelowInf, :(==)))
+    @eval function selectaction(policy::$typ, values)
+        if rand() < policy.ϵ
+            rand(1:length(values))
+        else
+            vmax = $max(values)
+            c = 1
+            a = 1
+            for (i, v) in enumerate(values)
+                if ($rel)(v, vmax)
+                    if rand() < 1/c
+                        a = i
+                    end
+                    c += 1
+                end
+            end
+            a
+        end
     end
-end
-
-function getactionprobabilities(policy::AbstractEpsilonGreedyPolicy, values)
-    p = ones(length(values))/length(values) * policy.ϵ
-    a = getgreedystates(policy, values)
-    p2 = (1. - policy.ϵ)/length(a)
-    for i in a
-        p[i] =+ p2
+    @eval function getactionprobabilities(policy::$typ, values)
+        p = ones(length(values))/length(values) * policy.ϵ
+        vmax = $max(values)
+        c = 0
+        for v in values
+            if ($rel)(v, vmax)
+                c += 1
+            end
+        end
+        p2 = (1. - policy.ϵ)/c
+        for (i, v) in enumerate(values)
+            if ($rel)(v, vmax)
+                p[i] += p2
+            end
+        end
+        p
     end
-    p
 end
 

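Where the speed-up comes from: the old code built an index array of all greedy actions with getgreedystates and then sampled from it, allocating on every call. The new selectaction instead picks uniformly among the tied maxima in a single allocation-free pass, a size-1 reservoir sample: the c-th tie replaces the running candidate with probability 1/c, so the k-th of m ties ends up selected with probability (1/k) * (k/(k+1)) * ... * ((m-1)/m) = 1/m. Likewise getactionprobabilities now counts the greedy actions inline: every action receives ϵ/length(values), and the remaining 1 - ϵ mass is split evenly over the c greedy ones. A minimal standalone sketch of the sampling trick (argmaxrand and the plain maximum/== comparison are illustrative choices, not names from the package):

    # One-pass uniform tie-breaking: a size-1 reservoir sample over the
    # maximal entries. `argmaxrand` is an illustrative name only.
    function argmaxrand(values)
        vmax = maximum(values)
        a = 1    # current candidate index
        c = 1    # count of maximal entries seen so far, plus one
        for (i, v) in enumerate(values)
            if v == vmax
                # the c-th tie replaces the candidate with probability 1/c,
                # leaving each of m ties selected with probability 1/m
                if rand() < 1/c
                    a = i
                end
                c += 1
            end
        end
        a
    end

For example, [argmaxrand([0., 1., 0., 1.]) for _ in 1:10^5] should return 2 and 4 roughly equally often.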
test/epsilongreedypolicies.jl

Lines changed: 14 additions & 8 deletions
@@ -1,10 +1,8 @@
-getgreedystates = ReinforcementLearning.getgreedystates
-for (v, rO, rVO, r, rP) in (([-9., 12., Inf64], [2, 3], [3], [3], [2]),
-                            ([-9., -12.], [1], [1], [1], [1]),
-                            ([Inf64, Inf64], [1, 2], [1, 2], [1, 2], [1, 2]))
-    @test getgreedystates(OptimisticEpsilonGreedyPolicy(0.), v) == rO
-    @test getgreedystates(VeryOptimisticEpsilonGreedyPolicy(0.), v) == rVO
-    @test getgreedystates(PesimisticEpsilonGreedyPolicy(0.), v) == rP
+import ReinforcementLearning: selectaction
+
+function empiricalactionprop(p, v; n = 10^6)
+    res = [selectaction(p, v) for _ in 1:n]
+    map(x -> length(find(i -> i == x, res)), 1:length(v))./n
 end
 
 for (v, rO, rVO, r, rP) in (([-9., 12., Inf64], [0, .5, .5], [0, 0., 1.],
@@ -16,7 +14,15 @@ for (v, rO, rVO, r, rP) in (([-9., 12., Inf64], [0, .5, .5], [0, 0., 1.],
     @test getactionprobabilities(OptimisticEpsilonGreedyPolicy(0.), v) == rO
     @test getactionprobabilities(VeryOptimisticEpsilonGreedyPolicy(0.), v) == rVO
     @test getactionprobabilities(PesimisticEpsilonGreedyPolicy(0.), v) == rP
+    @test isapprox(empiricalactionprop(OptimisticEpsilonGreedyPolicy(0.), v),
+                   rO, atol = .05)
+    @test isapprox(empiricalactionprop(VeryOptimisticEpsilonGreedyPolicy(0.), v),
+                   rVO, atol = .05)
+    @test isapprox(empiricalactionprop(PesimisticEpsilonGreedyPolicy(0.), v),
+                   rP, atol = .05)
+    @test isapprox(empiricalactionprop(OptimisticEpsilonGreedyPolicy(.2), v),
+                   getactionprobabilities(OptimisticEpsilonGreedyPolicy(.2), v),
+                   atol = .05)
 end
 
 
-

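Note that the find call in empiricalactionprop is Julia 0.6 API; it was renamed findall in Julia 0.7 and removed in 1.0. On current Julia the same empirical frequency estimate could be written as, for example (a sketch under that assumption, not part of the commit):

    import ReinforcementLearning: selectaction

    # Fraction of n draws in which each action 1:length(v) was selected;
    # `count(==(x), res)` replaces the Julia-0.6 `find` idiom.
    function empiricalactionprop(p, v; n = 10^6)
        res = [selectaction(p, v) for _ in 1:n]
        [count(==(x), res) for x in 1:length(v)] ./ n
    end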