|
13 | 13 | "text": [ |
14 | 14 | "\u001b[32m\u001b[1m Status\u001b[22m\u001b[39m `/mnt/E4E0A9C0E0A998F6/github/ReinforcementLearningAnIntroduction.jl/notebooks/Project.toml`\n", |
15 | 15 | " \u001b[90m [31c24e10]\u001b[39m\u001b[37m Distributions v0.22.4\u001b[39m\n", |
16 | | - " \u001b[90m [91a5bcdd]\u001b[39m\u001b[37m Plots v0.28.4\u001b[39m\n", |
17 | | - " \u001b[90m [02c1da58]\u001b[39m\u001b[37m RLIntro v0.2.0 [`..`]\u001b[39m\n", |
18 | | - " \u001b[90m [e575027e]\u001b[39m\u001b[37m ReinforcementLearningBase v0.5.0 [`~/workspace/github/ReinforcementLearningBase.jl`]\u001b[39m\n", |
19 | | - " \u001b[90m [de1b191a]\u001b[39m\u001b[37m ReinforcementLearningCore v0.1.0 [`~/workspace/github/ReinforcementLearningCore`]\u001b[39m\n", |
| 16 | + " \u001b[90m [91a5bcdd]\u001b[39m\u001b[37m Plots v0.29.1\u001b[39m\n", |
| 17 | + " \u001b[90m [02c1da58]\u001b[39m\u001b[37m ReinforcementLearningAnIntroduction v0.2.0 [`..`]\u001b[39m\n", |
| 18 | + " \u001b[90m [e575027e]\u001b[39m\u001b[37m ReinforcementLearningBase v0.5.0 [`../../ReinforcementLearningBase.jl`]\u001b[39m\n", |
| 19 | + " \u001b[90m [de1b191a]\u001b[39m\u001b[37m ReinforcementLearningCore v0.1.0 [`../../ReinforcementLearningCore`]\u001b[39m\n", |
20 | 20 | " \u001b[90m [2913bbd2]\u001b[39m\u001b[37m StatsBase v0.32.0\u001b[39m\n", |
21 | 21 | " \u001b[90m [f3b207a7]\u001b[39m\u001b[37m StatsPlots v0.12.0\u001b[39m\n", |
22 | 22 | " \u001b[90m [2f01184e]\u001b[39m\u001b[37m SparseArrays \u001b[39m\n" |
|
36 | 36 | "name": "stderr", |
37 | 37 | "output_type": "stream", |
38 | 38 | "text": [ |
39 | | - "┌ Info: Precompiling ReinforcementLearningCore [de1b191a-4ae0-4afa-a27b-92d07f46b2d6]\n", |
40 | | - "└ @ Base loading.jl:1273\n", |
41 | | - "┌ Info: Precompiling RLIntro [02c1da58-b9a1-11e8-0212-f9611b8fe936]\n", |
42 | | - "└ @ Base loading.jl:1273\n", |
43 | | - "┌ Warning: Package RLIntro does not have Flux in its dependencies:\n", |
44 | | - "│ - If you have RLIntro checked out for development and have\n", |
45 | | - "│ added Flux as a dependency but haven't updated your primary\n", |
46 | | - "│ environment's manifest file, try `Pkg.resolve()`.\n", |
47 | | - "│ - Otherwise you may need to report an issue with RLIntro\n", |
48 | | - "└ Loading Flux into RLIntro from project dependency, future warnings for RLIntro are suppressed.\n" |
| 39 | + "┌ Info: Precompiling ReinforcementLearningAnIntroduction [02c1da58-b9a1-11e8-0212-f9611b8fe936]\n", |
| 40 | + "└ @ Base loading.jl:1273\n" |
49 | 41 | ] |
50 | 42 | }, |
51 | 43 | { |
|
63 | 55 | } |
64 | 56 | ], |
65 | 57 | "source": [ |
66 | | - "using ReinforcementLearningCore, RLIntro\n", |
67 | | - "using RLIntro.TicTacToe\n", |
| 58 | + "using ReinforcementLearningAnIntroduction\n", |
68 | 59 | "\n", |
69 | 60 | "env = TicTacToeEnv()" |
70 | 61 | ] |
|
125 | 116 | { |
126 | 117 | "data": { |
127 | 118 | "text/plain": [ |
128 | | - "(reward = 0.0, terminal = false, state = 4193, legal_actions_mask = Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 0])" |
| 119 | + "(reward = 0.0, terminal = false, state = 4151, legal_actions_mask = Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 0])" |
129 | 120 | ] |
130 | 121 | }, |
131 | 122 | "execution_count": 5, |
|
221 | 212 | { |
222 | 213 | "data": { |
223 | 214 | "text/plain": [ |
224 | | - "MonteCarloLearner{RLIntro.EveryVisit,TabularApproximator{1,Array{Float64,1}},CachedSampleAvg{Float64},RLIntro.NoSampling}(TabularApproximator{1,Array{Float64,1}}([0.5, 1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5 … 0.5, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.5, 0.5]), 1.0, 0.1, CachedSampleAvg{Float64}(Dict{Float64,SampleAvg}()))" |
| 215 | + "MonteCarloLearner{ReinforcementLearningAnIntroduction.EveryVisit,TabularApproximator{1,Array{Float64,1}},CachedSampleAvg{Float64},ReinforcementLearningAnIntroduction.NoSampling}(TabularApproximator{1,Array{Float64,1}}([0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5 … 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), 1.0, 0.1, CachedSampleAvg{Float64}(Dict{Float64,SampleAvg}()))" |
225 | 216 | ] |
226 | 217 | }, |
227 | 218 | "execution_count": 8, |
|
261 | 252 | ], |
262 | 253 | "source": [ |
263 | 254 | "function create_mapping(role)\n", |
264 | | - " (obs, value_learner) -> begin\n", |
| 255 | + " (obs, learner) -> begin\n", |
265 | 256 | " mask = get_legal_actions_mask(obs)\n", |
266 | 257 | " [\n", |
267 | | - " mask[a] ? value_learner(StateOverriddenObs(obs=obs, state=TicTacToe.get_next_state_id(get_state(obs), role, a))) : 0. # a dummy value \n", |
| 258 | + " mask[a] ? learner(StateOverriddenObs(obs=obs, state=TicTacToe.get_next_state_id(get_state(obs), role, a))) : 0. # a dummy value \n", |
268 | 259 | " for a in action_space\n", |
269 | 260 | " ]\n", |
270 | 261 | " end\n", |
|
273 | 264 | }, |
274 | 265 | { |
275 | 266 | "cell_type": "code", |
276 | | - "execution_count": 15, |
| 267 | + "execution_count": 10, |
277 | 268 | "metadata": {}, |
278 | 269 | "outputs": [], |
279 | 270 | "source": [ |
280 | 271 | "ϵ = 0.01\n", |
281 | 272 | "\n", |
282 | 273 | "π_1 = VBasedPolicy(\n", |
283 | | - " value_learner = learner_1,\n", |
| 274 | + " learner = learner_1,\n", |
284 | 275 | " mapping = create_mapping(TicTacToe.offensive),\n", |
285 | 276 | " explorer = EpsilonGreedyExplorer(ϵ),\n", |
286 | 277 | " )\n", |
287 | 278 | "\n", |
288 | 279 | "π_2 = VBasedPolicy(\n", |
289 | | - " value_learner = learner_2,\n", |
| 280 | + " learner = learner_2,\n", |
290 | 281 | " mapping = create_mapping(TicTacToe.defensive),\n", |
291 | 282 | " explorer = EpsilonGreedyExplorer(ϵ),\n", |
292 | 283 | " );\n", |
|
310 | 301 | }, |
311 | 302 | { |
312 | 303 | "cell_type": "code", |
313 | | - "execution_count": null, |
| 304 | + "execution_count": 11, |
314 | 305 | "metadata": {}, |
315 | 306 | "outputs": [ |
316 | 307 | { |
317 | 308 | "name": "stderr", |
318 | 309 | "output_type": "stream", |
319 | 310 | "text": [ |
320 | | - "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:51\u001b[39mm46\u001b[39m\n" |
| 311 | + "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:42\u001b[39m8:41\u001b[39m\n" |
321 | 312 | ] |
| 313 | + }, |
| 314 | + { |
| 315 | + "data": { |
| 316 | + "text/plain": [ |
| 317 | + "2-element Array{EmptyHook,1}:\n", |
| 318 | + " EmptyHook()\n", |
| 319 | + " EmptyHook()" |
| 320 | + ] |
| 321 | + }, |
| 322 | + "execution_count": 11, |
| 323 | + "metadata": {}, |
| 324 | + "output_type": "execute_result" |
322 | 325 | } |
323 | 326 | ], |
324 | 327 | "source": [ |
|
355 | 358 | }, |
356 | 359 | { |
357 | 360 | "cell_type": "code", |
358 | | - "execution_count": 17, |
| 361 | + "execution_count": 13, |
359 | 362 | "metadata": {}, |
360 | 363 | "outputs": [ |
361 | 364 | { |
|
364 | 367 | "play (generic function with 1 method)" |
365 | 368 | ] |
366 | 369 | }, |
367 | | - "execution_count": 17, |
| 370 | + "execution_count": 13, |
368 | 371 | "metadata": {}, |
369 | 372 | "output_type": "execute_result" |
370 | 373 | } |
|
418 | 421 | }, |
419 | 422 | { |
420 | 423 | "cell_type": "code", |
421 | | - "execution_count": 20, |
| 424 | + "execution_count": 14, |
422 | 425 | "metadata": {}, |
423 | 426 | "outputs": [ |
424 | 427 | { |
|
435 | 438 | "___\n", |
436 | 439 | "isdone = [false], winner = [nothing]\n", |
437 | 440 | "\n", |
438 | | - "__O\n", |
439 | | - "_X_\n", |
440 | 441 | "___\n", |
| 442 | + "_X_\n", |
| 443 | + "O__\n", |
441 | 444 | "isdone = [false], winner = [nothing]\n", |
442 | 445 | "\n", |
443 | | - "Your input:stdin> 1\n", |
444 | | - "X_O\n", |
445 | | - "_X_\n", |
| 446 | + "Your input:stdin> 6\n", |
446 | 447 | "___\n", |
| 448 | + "_X_\n", |
| 449 | + "OX_\n", |
447 | 450 | "isdone = [false], winner = [nothing]\n", |
448 | 451 | "\n", |
449 | | - "X_O\n", |
| 452 | + "_O_\n", |
450 | 453 | "_X_\n", |
451 | | - "__O\n", |
| 454 | + "OX_\n", |
452 | 455 | "isdone = [false], winner = [nothing]\n", |
453 | 456 | "\n", |
454 | 457 | "Your input:stdin> 8\n", |
455 | | - "X_O\n", |
| 458 | + "_O_\n", |
456 | 459 | "_XX\n", |
457 | | - "__O\n", |
| 460 | + "OX_\n", |
458 | 461 | "isdone = [false], winner = [nothing]\n", |
459 | 462 | "\n", |
460 | | - "X_O\n", |
| 463 | + "_O_\n", |
461 | 464 | "OXX\n", |
462 | | - "__O\n", |
| 465 | + "OX_\n", |
463 | 466 | "isdone = [false], winner = [nothing]\n", |
464 | 467 | "\n", |
465 | | - "Your input:stdin> 6\n", |
466 | | - "X_O\n", |
| 468 | + "Your input:stdin> 1\n", |
| 469 | + "XO_\n", |
467 | 470 | "OXX\n", |
468 | | - "_XO\n", |
| 471 | + "OX_\n", |
469 | 472 | "isdone = [false], winner = [nothing]\n", |
470 | 473 | "\n", |
471 | | - "XOO\n", |
| 474 | + "XO_\n", |
472 | 475 | "OXX\n", |
473 | | - "_XO\n", |
| 476 | + "OXO\n", |
474 | 477 | "isdone = [false], winner = [nothing]\n", |
475 | 478 | "\n", |
476 | | - "Your input:stdin> 3\n", |
477 | | - "XOO\n", |
| 479 | + "Your input:stdin> 7\n", |
| 480 | + "XOX\n", |
478 | 481 | "OXX\n", |
479 | | - "XXO\n", |
| 482 | + "OXO\n", |
480 | 483 | "isdone = [true], winner = [nothing]\n", |
481 | 484 | "\n", |
482 | 485 | "Tie!\n" |
|
0 commit comments