MachineIntelligenceCore:ReinforcementLearning
GridworldQLearning.cpp

#include <limits>
#include <utils/RandomGenerator.hpp>

#include <application/GridworldQLearning.hpp> // class declaration (header path assumed)

namespace mic {
namespace application {

/*!
 * \brief Registers application.
 */
void RegisterApplication (void) {
	REGISTER_APPLICATION(mic::application::GridworldQLearning);
}

GridworldQLearning::GridworldQLearning(std::string node_name_) : OpenGLEpisodicApplication(node_name_),
		step_reward("step_reward", 0.0),
		discount_rate("discount_rate", 0.9),
		learning_rate("learning_rate", 0.1),
		move_noise("move_noise", 0.2),
		epsilon("epsilon", 0.1),
		statistics_filename("statistics_filename", "statistics_filename.csv")
	{
	// Register properties - so their values can be overridden (read from the configuration file).
	registerProperty(step_reward);
	registerProperty(discount_rate);
	registerProperty(learning_rate);
	registerProperty(move_noise);
	registerProperty(epsilon);
	registerProperty(statistics_filename);

	LOG(LINFO) << "Properties registered";
}
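
// Note: each property above is constructed as (name, default value); registering it
// allows the value to be overridden by an entry of the same name in the application
// configuration file (e.g. learning_rate or epsilon can be tuned without recompiling).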


GridworldQLearning::~GridworldQLearning() {
	delete(w_chart);
}


void GridworldQLearning::initialize(int argc, char* argv[]) {
	// Initialize GLUT! :]
	VGL_MANAGER->initializeGLUT(argc, argv);

	collector_ptr = std::make_shared < mic::utils::DataCollector<std::string, float> >( );
	// Add containers to collector.
	collector_ptr->createContainer("number_of_steps", mic::types::color_rgba(255, 0, 0, 180));
	collector_ptr->createContainer("average_number_of_steps", mic::types::color_rgba(255, 255, 0, 180));
	collector_ptr->createContainer("collected_reward", mic::types::color_rgba(0, 255, 0, 180));
	collector_ptr->createContainer("average_collected_reward", mic::types::color_rgba(0, 255, 255, 180));

	sum_of_iterations = 0;
	sum_of_rewards = 0;

	// Create the visualization windows - must be created in the same, main thread :]
	w_chart = new WindowCollectorChart<float>("GridworldQLearning", 256, 256, 0, 0);
	w_chart->setDataCollectorPtr(collector_ptr);
}

void GridworldQLearning::initializePropertyDependentVariables() {
	// Initialize the gridworld.
	grid_env.initializeEnvironment();

	// Resize and reset the action-value table.
	qstate_table.resize({grid_env.getEnvironmentWidth(), grid_env.getEnvironmentHeight(), 4});
	qstate_table.zeros();
	//qstate_table.setValue( -std::numeric_limits<float>::infinity() );

	LOG(LSTATUS) << std::endl << streamQStateTable();
}
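
// Note: qstate_table is a single tensor of size width x height x 4 (one entry per
// N/E/S/W action), stored column-major and indexed as qstate_table({x, y, action});
// zeroing it above means every state-action pair starts with the value 0.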


void GridworldQLearning::startNewEpisode() {
	LOG(LSTATUS) << "Starting new episode " << episode;

	// Generate the gridworld (and move player to initial position).
	grid_env.initializeEnvironment();

	LOG(LSTATUS) << std::endl << streamQStateTable();
	LOG(LSTATUS) << std::endl << grid_env.environmentToString();
}


void GridworldQLearning::finishCurrentEpisode() {
	LOG(LTRACE) << "End current episode";

	// Collect the reward of the state the agent finished in.
	float reward = grid_env.getStateReward(grid_env.getAgentPosition());
	sum_of_iterations += iteration;
	sum_of_rewards += reward;

	// Add variables to container.
	collector_ptr->addDataToContainer("number_of_steps", iteration);
	collector_ptr->addDataToContainer("average_number_of_steps", (float)sum_of_iterations/episode);
	collector_ptr->addDataToContainer("collected_reward", reward);
	collector_ptr->addDataToContainer("average_collected_reward", (float)sum_of_rewards/episode);

	// Export reward "convergence" diagram.
	collector_ptr->exportDataToCsv(statistics_filename);
}
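
// Note: the four series are updated once per finished episode, so the CSV written to
// statistics_filename traces learning progress over episodes: steps and collected
// reward per episode, plus their running averages.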


std::string GridworldQLearning::streamQStateTable() {
	std::string rewards_table;
	std::string actions_table;

	rewards_table += "Action values:\n";
	actions_table += "Best actions:\n";
	for (size_t y=0; y<grid_env.getEnvironmentHeight(); y++){
		rewards_table += "| ";
		actions_table += "| ";
		for (size_t x=0; x<grid_env.getEnvironmentWidth(); x++) {
			// Iterate through actions and find the best one.
			float bestqval = -std::numeric_limits<float>::infinity();
			size_t best_action = -1;
			for (size_t a=0; a<4; a++) {
				float qval = qstate_table({x,y,a});
				if ( qstate_table({x,y,a}) == -std::numeric_limits<float>::infinity())
					rewards_table += "-INF";
				else
					rewards_table += std::to_string(qval);
				if (a==3)
					rewards_table += " | ";
				else
					rewards_table += " , ";

				// Remember the best value.
				if (grid_env.isStateAllowed(x,y) && (!grid_env.isStateTerminal(x,y)) && grid_env.isActionAllowed(x,y,a) && (qval > bestqval)){
					bestqval = qval;
					best_action = a;
				}//: if

			}//: for a(ctions)
			switch(best_action){
				case 0 : actions_table += "N | "; break;
				case 1 : actions_table += "E | "; break;
				case 2 : actions_table += "S | "; break;
				case 3 : actions_table += "W | "; break;
				default: actions_table += "- | ";
			}//: switch

		}//: for x
		rewards_table += "\n";
		actions_table += "\n";
	}//: for y

	return rewards_table + actions_table;
}
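
// Note: the streamed output consists of two text grids: the first lists the four
// action values of every cell in N, E, S, W order, the second the best (greedy)
// action per cell, with "-" marking not-allowed states, terminal states and cells
// where no allowed action exceeds -infinity.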


float GridworldQLearning::computeBestValue(mic::types::Position2D pos_){
	float qbest_value = -std::numeric_limits<float>::infinity();
	// Check if the state is allowed.
	if (!grid_env.isStateAllowed(pos_))
		return qbest_value;

	// Create a list of possible actions.
	std::vector<mic::types::NESWAction> actions;
	actions.push_back(A_NORTH);
	actions.push_back(A_EAST);
	actions.push_back(A_SOUTH);
	actions.push_back(A_WEST);

	// Check the actions one by one.
	for(mic::types::NESWAction action : actions) {
		if(grid_env.isActionAllowed(pos_, action)) {
			float qvalue = qstate_table({(size_t)pos_.x, (size_t)pos_.y, (size_t)action.getType()});
			if (qvalue > qbest_value)
				qbest_value = qvalue;
		}//: if is allowed
	}//: for

	return qbest_value;
}
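
// Note: computeBestValue() returns max_a Q(pos_, a) over the actions allowed in the
// given state, i.e. the bootstrap target used by the Q-learning update in
// performSingleStep(); for states that are not allowed it returns -infinity.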

mic::types::NESWAction GridworldQLearning::selectBestAction(mic::types::Position2D pos_){
	LOG(LTRACE) << "Select best action for state" << pos_;

	// Greedy method - returns the action with the greatest value.
	mic::types::NESWAction best_action = A_NONE;
	float best_qvalue = -std::numeric_limits<float>::infinity();

	// Create a list of possible actions.
	std::vector<mic::types::NESWAction> actions;
	actions.push_back(A_NORTH);
	actions.push_back(A_EAST);
	actions.push_back(A_SOUTH);
	actions.push_back(A_WEST);

	// Check the actions one by one.
	for(mic::types::NESWAction action : actions) {
		if(grid_env.isActionAllowed(pos_, action)) {
			float qvalue = qstate_table({(size_t)pos_.x, (size_t)pos_.y, (size_t)action.getType()});
			std::cout << " qvalue = " << qvalue << std::endl;
			if (qvalue > best_qvalue){
				best_qvalue = qvalue;
				best_action = action;
				std::cout << " best_qvalue = " << best_qvalue << std::endl;
			}
		}//: if is allowed
	}//: for

	return best_action;
}
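
// Note: selectBestAction() implements the greedy policy: it scans the allowed actions
// in the fixed N, E, S, W order and keeps the first strictly greater Q-value, so ties
// are resolved in favour of the earlier action; A_NONE is returned when no allowed
// action exceeds the initial -infinity threshold (in particular, when no action is
// allowed at all).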

bool GridworldQLearning::performSingleStep() {
	LOG(LSTATUS) << "Episode " << episode << ": step " << iteration;

	// Get state s(t).
	mic::types::Position2D agent_pos_t = grid_env.getAgentPosition();

	// Check whether state is terminal.
	if(grid_env.isStateTerminal(agent_pos_t)) {
		// In the terminal state we can select only one special action: "terminate".
		// All "other" actions receive the same value related to the "reward".
		float final_reward = grid_env.getStateReward(agent_pos_t);
		for (size_t a=0; a<4; a++)
			qstate_table({(size_t)agent_pos_t.x, (size_t)agent_pos_t.y, a}) = final_reward;

		LOG(LINFO) << "Agent action = " << A_EXIT;
		LOG(LDEBUG) << "Agent position = " << agent_pos_t;
		LOG(LSTATUS) << std::endl << grid_env.environmentToString();
		LOG(LSTATUS) << std::endl << streamQStateTable();

		// Finish the episode.
		return false;
	}//: if terminal


	mic::types::NESWAction action;
	double eps = (double)epsilon;
	if ((double)epsilon < 0)
		eps = 1.0/(1.0+episode);
	LOG(LDEBUG) << "eps = " << eps;
	bool random = false;

	// Epsilon-greedy action selection.
	if (RAN_GEN->uniRandReal() > eps){
		// Select the best action.
		action = selectBestAction(agent_pos_t);
		// If the best action could not be found - select a random one.
		if (action.getType() == mic::types::NESW::None){
			action = A_RANDOM;
			random = true;
		}
	} else {
		// Random move.
		action = A_RANDOM;
		random = true;
	}//: if

	LOG(LINFO) << action << ((random) ? " [Random]" : "");

	// Execute the action.
	grid_env.moveAgent(action);
/*	while (!move(action)) {
		// If action could not be performed - random.
		action = A_RANDOM;
		random = true;
	}//: while*/

	// Get new state s(t+1).
	mic::types::Position2D agent_pos_t_prim = grid_env.getAgentPosition();

	LOG(LINFO) << "Agent position at t+1: " << agent_pos_t_prim << " after performing the action = " << action << ((random) ? " [Random]" : "");


	// Update the action value of the performed action - Q-learning ;)
	float q_st_at = qstate_table({(size_t)agent_pos_t.x, (size_t)agent_pos_t.y, (size_t)action.getType()});
	float r = step_reward;
	float max_q_st_prim_at_prim = computeBestValue(agent_pos_t_prim);
	LOG(LDEBUG) << "q_st_at = " << q_st_at;
	LOG(LDEBUG) << "agent_pos_t_prim = " << agent_pos_t_prim;
	LOG(LDEBUG) << "step_reward = " << step_reward;
	LOG(LDEBUG) << "max_q_st_prim_at_prim = " << max_q_st_prim_at_prim;
	//if (std::isfinite(q_st_at) && std::isfinite(max_q_st_prim_at_prim))
	if (agent_pos_t == agent_pos_t_prim)
		// The move did not change the position (e.g. the agent bumped into a wall) - double the step reward.
		qstate_table({(size_t)agent_pos_t.x, (size_t)agent_pos_t.y, (size_t)action.getType()}) = q_st_at + learning_rate * (2*r + discount_rate*max_q_st_prim_at_prim - q_st_at);
	else
		qstate_table({(size_t)agent_pos_t.x, (size_t)agent_pos_t.y, (size_t)action.getType()}) = q_st_at + learning_rate * (r + discount_rate*max_q_st_prim_at_prim - q_st_at);

	LOG(LSTATUS) << std::endl << streamQStateTable();
	LOG(LSTATUS) << std::endl << grid_env.environmentToString();

	return true;
}
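
// Note: the update above is standard tabular Q-learning,
//   Q(s_t, a_t) <- Q(s_t, a_t) + alpha * ( r + gamma * max_a Q(s_{t+1}, a) - Q(s_t, a_t) ),
// with alpha = learning_rate, gamma = discount_rate and r = step_reward; the only
// deviation is that a move which leaves the agent in place (e.g. bumping into a wall)
// uses a doubled step reward (2*r) in the target.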


} /* namespace application */
} /* namespace mic */