1+ classdef DrivingScenarioEnv < rl .env .MATLABEnvironment
2+ % Copyright 2020 The MathWorks, Inc.
3+ % MYENVIRONMENT: Template for defining custom environment in MATLAB.
4+
5+ % parameters for simulation environment
6+ properties
7+ scenario
8+ network
9+ traffic
10+ cars
11+ state
12+ driver
13+ InjectionRate
14+ TurnRatio
15+ N = 3 % number of road
16+ phaseDuration = 50 % time duration for each of the phase
17+ T
18+ end
19+
20+ % simulation doesn't have yellow light
21+ % manually set up clearning phase here if needed
22+ properties
23+ clearingPhase = false
24+ clearingPhaseTime = 0
25+ TrafficSignalDesign
26+ ObservationSpaceDesign
27+ end
28+
29+ % parameter for reward definition
30+ properties
31+ rewardForPass = 0
32+ vehicleEnterJunction % keep record of cars pass the intersection
33+ hitPenalty = 20
34+ penaltyForFreqSwitch = 1
35+ safeDistance = 2.25 % check collision
36+ slowSpeedThreshold = 3.5 % check whether car is waiting
37+ end
38+
39+ properties
40+ recordVid = false
41+ vid
42+ end
43+
44+ properties
45+ discrete_action = [0 1 2 ];
46+ dim = 10 ;
47+ end
48+
49+ properties (Access = protected )
50+ IsDone = false
51+ end
52+
53+ %% Necessary Methods
54+ methods
55+ function this = DrivingScenarioEnv()
56+ % Initialize Observation settings
57+ ObservationInfo = rlNumericSpec([10 , 1 ]); % # of state
58+ ObservationInfo.Name = ' real-time traffic information' ;
59+ ObservationInfo.Description = ' ' ;
60+
61+ % Initialize action settings
62+ ActionInfo = rlFiniteSetSpec([0 1 2 ]); % three phases
63+ ActionInfo.Name = ' traffic signal phases' ;
64+
65+ % The following line implements built-in functions of the RL environment
66+ this = this@rl.env.MATLABEnvironment(ObservationInfo ,ActionInfo );
67+ end
68+
69+ function [state , Reward ,IsDone ,LoggedSignals ] = step(this , Action )
70+ Action = getForce(this , Action );
71+ % update the action
72+ pre_phase = this .traffic .IsOpen ;
73+ if this .TrafficSignalDesign == 1
74+ cur_phase = signalPhaseDesign1(Action );
75+ elseif this .TrafficSignalDesign == 2
76+ cur_phase = signalPhaseDesign2(Action );
77+ elseif this .TrafficSignalDesign == 3
78+ cur_phase = signalPhaseDesign3(Action );
79+ end
80+
81+ % Reward: penalty for signal phase switch
82+ changed = ~isequal(pre_phase , cur_phase );
83+ Reward = this .penaltyForFreqSwitch * (1 - changed );
84+
85+ % (yellow light time)add clearing phase when signal phase switch
86+ if changed && this .clearingPhase
87+ for i = 1 : this .clearingPhaseTime
88+ this.traffic.IsOpen = [0 , 0 , 0 , 0 , 0 , 0 ];
89+ advance(this .scenario );
90+ this.T = this .T + this .scenario .SampleTime ;
91+ notifyEnvUpdated(this );
92+ % check terminal condition
93+ IsHit = checkCollision(this );
94+ Reward = Reward - IsHit * this .hitPenalty ;
95+ this.IsDone = IsHit || this .T + 0.5 >= this .scenario .StopTime ;
96+ if this .IsDone
97+ break
98+ end
99+ end
100+ end
101+
102+ % (green light time)simulate the signal phase based on the action by RL
103+ this.traffic.IsOpen = cur_phase ;
104+ if ~this .IsDone
105+ for i = 1 : this .phaseDuration
106+ % update traffic state
107+ advance(this .scenario );
108+ this.T = this .T + this .scenario .SampleTime ;
109+ % update visulization
110+ notifyEnvUpdated(this );
111+ % check terminal condition
112+ IsHit = checkCollision(this );
113+ Reward = Reward - IsHit * this .hitPenalty ;
114+ this.IsDone = IsHit || this .T + 0.5 >= this .scenario .StopTime ;
115+ if this .IsDone
116+ break
117+ end
118+ % obtain reward
119+ Reward = Reward + obtainReward(this , cur_phase );
120+ end
121+ end
122+ if this .ObservationSpaceDesign == 1
123+ state = observationSpace1(this , Action );
124+ else
125+ state = observationSpace2(this , Action );
126+ end
127+ this.state = state ;
128+ IsDone = this .IsDone ;
129+ LoggedSignals = [];
130+ end
131+
132+
133+ function InitialState = reset(this )
134+ % flag for record simulation
135+ this.recordVid = false ;
136+ % Initialize scenario
137+ this.scenario = createTJunctionScenario();
138+ this.scenario.StopTime = 100 ;
139+ this.scenario.SampleTime = 0.05 ;
140+ this.T = 0 ;
141+ % initialize network
142+ this.network = createTJunctionNetwork(this .scenario );
143+ this.traffic = trafficControl .TrafficController(this .network(7 : 12 ));
144+ % car parameters
145+ this.InjectionRate = [250 , 250 , 250 ]; % veh/hour
146+ this.TurnRatio = [50 , 50 ];
147+ this.cars = createVehiclesForTJunction(this .scenario , this .network , this .InjectionRate , this .TurnRatio );
148+ this.vehicleEnterJunction = [];
149+ % obtain state from traffic and network
150+ if this .ObservationSpaceDesign == 1
151+ InitialState = observationSpace1(this , 0 );
152+ else
153+ InitialState = observationSpace2(this , 0 );
154+ end
155+ % visulization
156+ notifyEnvUpdated(this );
157+ end
158+ end
159+
160+ methods
161+ function force = getForce(this ,action )
162+ if ~ismember(action ,this .ActionInfo .Elements )
163+ error(' Action must be integer from 1 to numAction' );
164+ end
165+ force = action ;
166+ end
167+ % update the action info based on max force
168+ function updateActionInfo(this )
169+ this.ActionInfo.Elements = this .discrete_action ;
170+ end
171+ end
172+
173+ methods (Access = protected )
174+ function envUpdatedCallback(this )
175+ if this .T == 0
176+ close all ;
177+ plot(this .scenario )
178+ set(gcf ,' Visible' ,' On' );
179+ if this .recordVid
180+ this.vid = VideoWriter(' baseRLlearningProcess33' );
181+ this.vid.FrameRate= 20 ;
182+ open(this .vid )
183+ end
184+ end
185+ if this .recordVid
186+ frame = getframe(gcf );
187+ writeVideo(this .vid ,frame );
188+ end
189+ this .traffic .plotOpenPaths()
190+ drawnow
191+ end
192+ end
193+ end
0 commit comments