MachineIntelligenceCore:ReinforcementLearning
MNISTDigitDLRERPOMDP.cpp
1 
24 
25 #include <limits>
26 #include <utils/RandomGenerator.hpp>
27 
28 
29 namespace mic {
30 namespace application {
31 
36 void RegisterApplication (void) {
37  REGISTER_APPLICATION(mic::application::MNISTDigitDLRERPOMDP);
38 }
39 
40 
41 MNISTDigitDLRERPOMDP::MNISTDigitDLRERPOMDP(std::string node_name_) : OpenGLEpisodicApplication(node_name_),
42  saccadic_path(new std::vector <mic::types::Position2D>()),
43  step_reward("step_reward", 0.0),
44  discount_rate("discount_rate", 0.9),
45  learning_rate("learning_rate", 0.005),
46  epsilon("epsilon", 0.1),
47  step_limit("step_limit",0),
48  statistics_filename("statistics_filename","mnist_digit_drl_er_statistics.csv"),
49  mlnn_filename("mlnn_filename", "mnist_digit_drl_er_mlnn.txt"),
50  mlnn_save("mlnn_save", false),
51  mlnn_load("mlnn_load", false),
52  experiences(10000,1)
53  {
54  // Register properties - so their values can be overridden (read from the configuration file).
55  registerProperty(step_reward);
56  registerProperty(discount_rate);
57  registerProperty(learning_rate);
58  registerProperty(epsilon);
59  registerProperty(step_limit);
60  registerProperty(statistics_filename);
61  registerProperty(mlnn_filename);
62  registerProperty(mlnn_save);
63  registerProperty(mlnn_load);
64 
65  LOG(LINFO) << "Properties registered";
66 }
67 
68 
69 MNISTDigitDLRERPOMDP::~MNISTDigitDLRERPOMDP() {
70  delete(w_chart);
71  delete(wmd_environment);
72  delete(wmd_observation);
73 }
74 
75 
76 void MNISTDigitDLRERPOMDP::initialize(int argc, char* argv[]) {
77  // Initialize GLUT! :]
78  VGL_MANAGER->initializeGLUT(argc, argv);
79 
80  collector_ptr = std::make_shared < mic::utils::DataCollector<std::string, float> >( );
81  // Add containers to collector.
82  collector_ptr->createContainer("path_length_episode", mic::types::color_rgba(0, 255, 0, 180));
83  collector_ptr->createContainer("path_length_average", mic::types::color_rgba(255, 255, 0, 180));
84  collector_ptr->createContainer("path_length_optimal", mic::types::color_rgba(255, 255, 255, 180));
85  collector_ptr->createContainer("path_length_diff", mic::types::color_rgba(255, 0, 0, 180));
86 
88 
89  // Create the visualization windows - they must be created in the same (main) thread :]
90  w_chart = new WindowCollectorChart<float>("MNISTDigitDLRERPOMDP", 256, 512, 0, 0);
91  w_chart->setDataCollectorPtr(collector_ptr);
92 
93 }
94 
95 void MNISTDigitDLRERPOMDP::initializePropertyDependentVariables() {
96  // Create windows for the visualization of the whole environment and a single observation.
97  wmd_environment = new WindowMNISTDigit("Environment", env.getEnvironmentHeight()*20,env.getEnvironmentWidth()*20, 0, 316);
98  wmd_observation = new WindowMNISTDigit("Observation", env.getObservationHeight()*20,env.getObservationWidth()*20, env.getEnvironmentWidth()*20, 316);
99 
100 
101  // Hardcode the batch size - to speed up the display!
103 
104  // Try to load neural network from file.
105  if ((mlnn_load) && (neural_net.load(mlnn_filename))) {
106  // Do nothing ;)
107  } else {
108  // Create a simple neural network.
109  // Observation size -> 250 -> 100 -> 4 -> regression.
110  neural_net.pushLayer(new Linear<float>((size_t) env.getObservationSize(), 250));
111  neural_net.pushLayer(new ReLU<float>(250));
112  neural_net.pushLayer(new Linear<float>(250, 100));
113  neural_net.pushLayer(new ReLU<float>(100));
114  neural_net.pushLayer(new Linear<float>(100, 4));
115 
116  // Set batch size.
117  neural_net.resizeBatch(batch_size);
118  // Change optimization function from default GradientDescent to Adam.
119  neural_net.setOptimization<mic::neural_nets::optimization::Adam<float> >();
120  // Set loss function -> regression!
121  neural_net.setLoss <mic::neural_nets::loss::SquaredErrorLoss<float> >();
122 
123  LOG(LINFO) << "Generated new neural network";
124  }//: else
125 
126  // Set batch size in experience replay memory.
127  experiences.setBatchSize(batch_size);
128 
129  // Set displayed matrix pointers.
130  wmd_environment->setDigitPointer(env.getEnvironment());
131  wmd_environment->setPathPointer(saccadic_path);
132  wmd_observation->setDigitPointer(env.getObservation());
133 
134 }
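
For reference, the layer stack above is a plain multi-layer perceptron mapping a flattened observation onto four action values (N, E, S, W). A minimal, self-contained sketch of that forward pass written directly against Eigen follows; the weight matrices are random placeholders (the real ones are learned by the BackpropagationNeuralNetwork) and biases are omitted.

#include <Eigen/Dense>

// Sketch only: observation -> 250 -> 100 -> 4, with ReLU between the linear layers.
Eigen::VectorXf qValuesSketch(const Eigen::VectorXf& observation) {
    Eigen::MatrixXf W1 = Eigen::MatrixXf::Random(250, observation.size());
    Eigen::MatrixXf W2 = Eigen::MatrixXf::Random(100, 250);
    Eigen::MatrixXf W3 = Eigen::MatrixXf::Random(4, 100);
    Eigen::VectorXf h1 = (W1 * observation).cwiseMax(0.0f); // Linear + ReLU
    Eigen::VectorXf h2 = (W2 * h1).cwiseMax(0.0f);          // Linear + ReLU
    return W3 * h2;                                          // four Q-values: N, E, S, W
}
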
135 
136 
137 void MNISTDigitDLRERPOMDP::startNewEpisode() {
138  LOG(LSTATUS) << "Starting new episode " << episode;
139 
140  // Initialize the environment (and move the player to its initial position).
141  env.initializeEnvironment();
142  saccadic_path->clear();
143  // Add the first, initial position to the saccadic path.
144  saccadic_path->push_back(env.getAgentPosition());
145 
146  /*LOG(LNOTICE) << "Network responses: \n" << streamNetworkResponseTable();
147  LOG(LNOTICE) << "Observation: \n" << env.observationToString();
148  LOG(LNOTICE) << "Environment: \n" << env.environmentToString();*/
149  // Do not forget to get the current observation!
150  env.getObservation();
151 }
152 
153 
154 void MNISTDigitDLRERPOMDP::finishCurrentEpisode() {
155  LOG(LTRACE) << "End current episode";
156 
157  sum_of_iterations += iteration -1; // -1 is the fix related to moving the terminal condition to the front of step!
158 
159  // Add variables to container.
160  collector_ptr->addDataToContainer("path_length_episode",(iteration -1));
161  collector_ptr->addDataToContainer("path_length_average",(float)sum_of_iterations/episode);
162  collector_ptr->addDataToContainer("path_length_optimal", (float)env.optimalPathLength());
163  collector_ptr->addDataToContainer("path_length_diff", (float)(iteration -1 - env.optimalPathLength()));
164 
165 
166  // Export reward "convergence" diagram.
167  collector_ptr->exportDataToCsv(statistics_filename);
168 
169  // Save nn to file.
170  if (mlnn_save && (episode % 10))
171  neural_net.save(mlnn_filename);
172 }
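
The statistics collected above are simple running aggregates: the episode path length (iteration - 1, since the terminal check sits at the front of the step), its running average over all finished episodes, and the difference to the optimal path. A minimal sketch of that bookkeeping, with hypothetical names standing in for the application's members:

#include <cstddef>

// Sketch only: the per-episode aggregation performed in finishCurrentEpisode().
struct PathStats {
    unsigned long sum_of_iterations = 0;   // accumulated over all finished episodes

    float episode_length  = 0.0f;          // last episode's path length
    float running_average = 0.0f;          // average path length so far
    float diff_to_optimal = 0.0f;          // last length minus optimal length

    void update(std::size_t episode, std::size_t iteration, unsigned int optimal_length) {
        const std::size_t length = iteration - 1;
        sum_of_iterations += length;
        episode_length  = static_cast<float>(length);
        running_average = static_cast<float>(sum_of_iterations) / episode;
        diff_to_optimal = episode_length - static_cast<float>(optimal_length);
    }
};
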
173 
174 
175 std::string MNISTDigitDLRERPOMDP::streamNetworkResponseTable() {
176  LOG(LTRACE) << "streamNetworkResponseTable()";
177  std::string rewards_table;
178  std::string actions_table;
179 
180  // Remember the current state i.e. player position.
181  mic::types::Position2D current_player_pos_t = env.getAgentPosition();
182 
183  // Create new matrices for batches of inputs and targets.
184  MatrixXfPtr inputs_batch(new MatrixXf(env.getObservationSize(), batch_size));
185 
186  // Assume that batch_size = env.getObservationWidth() * env.getObservationHeight().
188 
189 
190  size_t dx = (env.getObservationWidth()-1)/2;
191  size_t dy = (env.getObservationHeight()-1)/2;
192  mic::types::Position2D p = env.getAgentPosition();
193 
194  // Copy data.
195  for (long oy=0, ey=(p.y-dy); oy<(long)env.getObservationHeight(); oy++, ey++){
196  for (long ox=0, ex=(p.x-dx); ox<(long)env.getObservationWidth(); ox++, ex++) {
197 
198  // Move the player to given state - disregarding whether it was successful or not; answers for walls/positions outside of the environment do not interest us anyway...
199  if (!env.moveAgentToPosition(Position2D(ex,ey)))
200  LOG(LDEBUG) << "Failed!"; //... but still we can live with that... ;)
201  // Encode the current state.
202  mic::types::MatrixXfPtr encoded_state = env.encodeObservation();
203  // Add to batch.
204  inputs_batch->col(oy*env.getObservationWidth()+ox) = encoded_state->col(0);
205  }//: for x
206  }//: for y
207 
208  // Get rewards for the whole batch.
209  neural_net.forward(inputs_batch);
210  // Get predictions for all those states - there is no need to create a copy.
211  MatrixXfPtr predicted_batch = neural_net.getPredictions();
212 
213 
214  rewards_table += "Action values:\n";
215  actions_table += "Best actions:\n";
216  // Generate all possible states and all possible rewards.
217  for (long oy=0, ey=(p.y-dy); oy<(long)env.getObservationHeight(); oy++, ey++){
218  rewards_table += "| ";
219  actions_table += "| ";
220  for (long ox=0, ex=(p.x-dx); ox<(long)env.getObservationWidth(); ox++, ex++) {
221  float bestqval = -std::numeric_limits<float>::infinity();
222  size_t best_action = -1;
223  for (size_t a=0; a<4; a++) {
224  float qval = (*predicted_batch)(a, oy*env.getObservationWidth()+ox);
225 
226  rewards_table += std::to_string(qval);
227  if (a==3)
228  rewards_table += " | ";
229  else
230  rewards_table += " , ";
231 
232  // Remember the best value.
233  if (env.isStateAllowed(ex,ey) && (!env.isStateTerminal(ex,ey)) && env.isActionAllowed(ex,ey,a) && (qval > bestqval)){
234  bestqval = qval;
235  best_action = a;
236  }//: if
237 
238  }//: for a(ctions)
239  switch(best_action){
240  case 0 : actions_table += "N | "; break;
241  case 1 : actions_table += "E | "; break;
242  case 2 : actions_table += "S | "; break;
243  case 3 : actions_table += "W | "; break;
244  default: actions_table += "- | ";
245  }//: switch
246 
247  }//: for x
248  rewards_table += "\n";
249  actions_table += "\n";
250  }//: for y
251 
252  // Move player to previous position.
253  env.moveAgentToPosition(current_player_pos_t);
254 
255  return rewards_table + actions_table;
256 }
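
In the table generated above, every cell of the observation window becomes one column of the inputs batch: the window is centred on the agent, so cell (ox, oy) corresponds to environment position (p.x - dx + ox, p.y - dy + oy) and to batch column oy * observation_width + ox. A minimal sketch of that mapping (the struct and function names below are illustrative, not part of the application):

#include <cstddef>

// Sketch only: observation-window cell -> environment coordinates and batch column.
struct ObservationCell { long ex; long ey; std::size_t col; };

ObservationCell mapCell(std::size_t ox, std::size_t oy,
                        std::size_t obs_width, std::size_t obs_height,
                        long agent_x, long agent_y) {
    const long dx = (static_cast<long>(obs_width)  - 1) / 2;  // window half-width
    const long dy = (static_cast<long>(obs_height) - 1) / 2;  // window half-height
    return { agent_x - dx + static_cast<long>(ox),            // environment x
             agent_y - dy + static_cast<long>(oy),            // environment y
             oy * obs_width + ox };                           // column in the inputs batch
}
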
257 
258 
259 
260 float MNISTDigitDLRERPOMDP::computeBestValueForGivenStateAndPredictions(mic::types::Position2D player_position_, float* predictions_){
261  LOG(LTRACE) << "computeBestValueForGivenStateAndPredictions()";
262  float best_qvalue = -std::numeric_limits<float>::infinity();
263 
264  // Create a list of possible actions.
265  std::vector<mic::types::NESWAction> actions;
266  actions.push_back(A_NORTH);
267  actions.push_back(A_EAST);
268  actions.push_back(A_SOUTH);
269  actions.push_back(A_WEST);
270 
271  for(mic::types::NESWAction action : actions) {
272  // .. and find the value of the best allowed action.
273  if(env.isActionAllowed(player_position_, action)) {
274  float qvalue = predictions_[(size_t)action.getType()];
275  if (qvalue > best_qvalue)
276  best_qvalue = qvalue;
277  }//if is allowed
278  }//: for
279 
280  return best_qvalue;
281 }
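
computeBestValueForGivenStateAndPredictions() amounts to a masked maximum over the four action values: the maximum of Q(s, a) over allowed actions, or -infinity when no action is allowed. A standalone sketch, with a boolean mask standing in for env.isActionAllowed():

#include <array>
#include <cstddef>
#include <limits>

// Sketch only: masked maximum over the four action values (N, E, S, W).
// 'allowed[a]' plays the role of env.isActionAllowed(position, action).
float bestAllowedValue(const std::array<float, 4>& q_values,
                       const std::array<bool, 4>& allowed) {
    float best = -std::numeric_limits<float>::infinity();
    for (std::size_t a = 0; a < 4; ++a)
        if (allowed[a] && q_values[a] > best)
            best = q_values[a];
    return best;  // stays -infinity if no action is allowed, as in the method above
}
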
282 
283 
284 mic::types::MatrixXfPtr MNISTDigitDLRERPOMDP::getPredictedRewardsForGivenState(mic::types::Position2D player_position_) {
285  LOG(LTRACE) << "getPredictedRewardsForGivenState()";
286  // Remember the current state i.e. player position.
287  mic::types::Position2D current_player_pos_t = env.getAgentPosition();
288 
289  // Move the player to given state.
290  env.moveAgentToPosition(player_position_);
291 
292  // Encode the current state.
293  mic::types::MatrixXfPtr encoded_state = env.encodeObservation();
294 
295  // Create NEW matrix for the inputs batch.
296  MatrixXfPtr inputs_batch(new MatrixXf(env.getObservationSize(), batch_size));
297  inputs_batch->setZero();
298 
299  // Set the first input - only this one interests us.
300  inputs_batch->col(0) = encoded_state->col(0);
301 
302  //LOG(LERROR) << "Getting predictions for input batch:\n" <<inputs_batch->transpose();
303 
304  // Pass the data and get predictions.
305  neural_net.forward(inputs_batch);
306 
307  MatrixXfPtr predictions_batch = neural_net.getPredictions();
308 
309  //LOG(LERROR) << "Resulting predictions batch:\n" << predictions_batch->transpose();
310 
311  // Get the first prediction only.
312  MatrixXfPtr predictions_sample(new MatrixXf(4, 1));
313  predictions_sample->col(0) = predictions_batch->col(0);
314 
315  //LOG(LERROR) << "Returned predictions sample:\n" << predictions_sample->transpose();
316 
317  // Move player to previous position.
318  env.moveAgentToPosition(current_player_pos_t);
319 
320  // Return the predictions.
321  return predictions_sample;
322 }
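
Because the network operates on a fixed batch size, getPredictedRewardsForGivenState() pads a single encoded observation into a zeroed batch, runs the forward pass, and keeps only the first prediction column. A sketch of that padding pattern, with a generic callback standing in for neural_net.forward()/getPredictions():

#include <functional>
#include <Eigen/Dense>

// Sketch only: single-sample query against a fixed-size batch interface.
Eigen::VectorXf predictSingle(const Eigen::VectorXf& encoded_state, int batch_size,
        const std::function<Eigen::MatrixXf(const Eigen::MatrixXf&)>& forward) {
    Eigen::MatrixXf inputs = Eigen::MatrixXf::Zero(encoded_state.size(), batch_size);
    inputs.col(0) = encoded_state;     // only the first column carries a real sample
    Eigen::MatrixXf predictions = forward(inputs);
    return predictions.col(0);         // the four action values for that sample
}
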
323 
324 mic::types::NESWAction MNISTDigitDLRERPOMDP::selectBestActionForGivenState(mic::types::Position2D player_position_){
325  LOG(LTRACE) << "selectBestActionForGivenState()";
326 
327  // Greedy methods - returns the index of element with greatest value.
328  mic::types::NESWAction best_action = A_RANDOM;
329  float best_qvalue = -std::numeric_limits<float>::infinity();
330 
331  // Create a list of possible actions.
332  std::vector<mic::types::NESWAction> actions;
333  actions.push_back(A_NORTH);
334  actions.push_back(A_EAST);
335  actions.push_back(A_SOUTH);
336  actions.push_back(A_WEST);
337 
338  // Check the results of actions one by one... (there is no need to create a separate copy of predictions)
339  MatrixXfPtr predictions_sample = getPredictedRewardsForGivenState(player_position_);
340  //LOG(LERROR) << "Selecting action from predictions:\n" << predictions_sample->transpose();
341  float* pred = predictions_sample->data();
342 
343  for(size_t a=0; a<4; a++) {
344  // Find the best action allowed.
345  if(env.isActionAllowed(player_position_, mic::types::NESWAction((mic::types::NESW)a))) {
346  float qvalue = pred[a];
347  if (qvalue > best_qvalue){
348  best_qvalue = qvalue;
349  best_action.setAction((mic::types::NESW)a);
350  }
351  }//if is allowed
352  }//: for
353 
354  return best_action;
355 }
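
performSingleStep() below wraps this greedy choice in an epsilon-greedy policy: when the epsilon property is negative, the exploration rate decays as 1/(1 + sqrt(episode)), and in every case it is clamped from below at 0.1. A minimal sketch of that schedule and of the exploration decision, with a standard RNG in place of RAN_GEN:

#include <cmath>
#include <cstddef>
#include <random>

// Sketch only: the effective exploration rate used in performSingleStep().
double effectiveEpsilon(double epsilon_property, std::size_t episode) {
    double eps = epsilon_property;
    if (epsilon_property < 0.0)
        eps = 1.0 / (1.0 + std::sqrt(static_cast<double>(episode)));  // decay with episodes
    return (eps < 0.1) ? 0.1 : eps;                                   // never below 0.1
}

// Sketch only: returns true when the best (greedy) action should be taken.
bool exploit(double eps, std::mt19937& rng) {
    std::uniform_real_distribution<double> uni(0.0, 1.0);
    return uni(rng) > eps;  // otherwise a random action is selected
}
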
356 
357 bool MNISTDigitDLRERPOMDP::performSingleStep() {
358  LOG(LSTATUS) << "Episode " << episode << ": step " << iteration;
359 
360  // Get player pos at time t.
361  mic::types::Position2D player_pos_t= env.getAgentPosition();
362  LOG(LINFO) << "Agent position at state t: " << player_pos_t;
363 
364  // Check whether state t is terminal - finish the episode.
365  if(env.isStateTerminal(player_pos_t))
366  return false;
367 
368  // TMP!
369  double nn_weight_decay = 0;
370 
371  // Select the action.
372  mic::types::NESWAction action;
373  //action = A_NORTH;
374  double eps = (double)epsilon;
375  if ((double)epsilon < 0)
376  eps = 1.0/(1.0+sqrt(episode));
377  if (eps < 0.1)
378  eps = 0.1;
379  LOG(LDEBUG) << "eps = " << eps;
380  bool random = false;
381 
382  // Epsilon-greedy action selection.
383  if (RAN_GEN->uniRandReal() > eps){
384  // Select best action.
385  action = selectBestActionForGivenState(player_pos_t);
386  } else {
387  // Random action.
388  action = A_RANDOM;
389  random = true;
390  }//: if
391 
392  // Execute action - do not monitor the success.
393  env.moveAgent(action);
394 
395  // Get new state s(t+1).
396  mic::types::Position2D player_pos_t_prim = env.getAgentPosition();
397  LOG(LINFO) << "Agent position at t+1: " << player_pos_t_prim << " after performing the action = " << action << ((random) ? " [Random]" : "");
398 
399  // Add this position to the saccadic path.
400  saccadic_path->push_back(player_pos_t_prim);
401 
402  // Collect the experience.
403  SpatialExperiencePtr exp(new SpatialExperience(player_pos_t, action, player_pos_t_prim));
404  // Create an empty matrix for rewards - it will be recalculated each time the experience is replayed anyway.
405  MatrixXfPtr rewards (new MatrixXf(4 , batch_size));
406  // Add experience to experience table.
407  experiences.add(exp, rewards);
408 
409 
410  // Deep Q learning - train network with random sample from the experience memory.
411  if (experiences.size() >= 2*batch_size) {
412  // Create new matrices for batches of inputs and targets.
413  MatrixXfPtr inputs_t_batch(new MatrixXf(env.getObservationSize(), batch_size));
414  MatrixXfPtr inputs_t_prim_batch(new MatrixXf(env.getObservationSize(), batch_size));
415  MatrixXfPtr targets_t_batch(new MatrixXf(4, batch_size));
416 
417  // Get the random batch.
418  SpatialExperienceBatch geb = experiences.getRandomBatch();
419 
420  // Debug purposes.
421  geb.setNextSampleIndex(0);
422  for (size_t i=0; i<batch_size; i++) {
423  SpatialExperienceSample ges = geb.getNextSample();
424  SpatialExperiencePtr ge_ptr = ges.data();
425  LOG(LDEBUG) << "Training sample : " << ge_ptr->s_t << " -> " << ge_ptr->a_t << " -> " << ge_ptr->s_t_prim;
426  }//: for
427 
428  // Iterate through samples and create inputs_t_batch.
429  for (size_t i=0; i<batch_size; i++) {
430  SpatialExperienceSample ges = geb.getNextSample();
431  SpatialExperiencePtr ge_ptr = ges.data();
432 
433  // Replay the experience.
434  // "Simulate" moving player to position from state/time (t).
435  env.moveAgentToPosition(ge_ptr->s_t);
436  // Encode the state at time (t).
437  mic::types::MatrixXfPtr encoded_state_t = env.encodeObservation();
438  //float* state = encoded_state_t->data();
439 
440  // Copy the encoded state to inputs batch.
441  inputs_t_batch->col(i) = encoded_state_t->col(0);
442  }// for samples.
443 
444  // Get network responses.
445  neural_net.forward(inputs_t_batch);
446  // Get predictions for all those states...
447  MatrixXfPtr predictions_t_batch = neural_net.getPredictions();
448  // ... and copy them to the targets matrix - the container which we will modify.
449  (*targets_t_batch) = (*predictions_t_batch);
450 
451  // Iterate through samples and create inputs_t_prim_batch.
452  geb.setNextSampleIndex(0);
453  for (size_t i=0; i<batch_size; i++) {
454  SpatialExperienceSample ges = geb.getNextSample();
455  SpatialExperiencePtr ge_ptr = ges.data();
456 
457  // Replay the experience.
458  // "Simulate" moving player to position from state/time (t+1).
459  env.moveAgentToPosition(ge_ptr->s_t_prim);
460  // Encode the state at time (t+1).
461  mic::types::MatrixXfPtr encoded_state_t = env.encodeObservation();
462  //float* state = encoded_state_t->data();
463 
464  // Copy the encoded state to inputs batch.
465  inputs_t_prim_batch->col(i) = encoded_state_t->col(0);
466  }// for samples.
467 
468  // Get network responses.
469  neural_net.forward(inputs_t_prim_batch);
470  // Get predictions for all those states...
471  MatrixXfPtr predictions_t_prim_batch = neural_net.getPredictions();
472 
473  // Calculate the rewards, one by one.
474  // Iterate through samples and compute the targets.
475  geb.setNextSampleIndex(0);
476  for (size_t i=0; i<batch_size; i++) {
477  SpatialExperienceSample ges = geb.getNextSample();
478  SpatialExperiencePtr ge_ptr = ges.data();
479 
480  if (ge_ptr->s_t == ge_ptr->s_t_prim) {
481  // The move was not possible! Learn that as well.
482  (*targets_t_batch)((size_t)ge_ptr->a_t.getType(), i) = 3*step_reward;
483  } else if(env.isStateTerminal(ge_ptr->s_t_prim)) {
484  // The position at (t+1) state appears to be terminal - learn the reward.
485  (*targets_t_batch)((size_t)ge_ptr->a_t.getType(), i) = env.getStateReward(ge_ptr->s_t_prim);
486  } else {
487  MatrixXfPtr preds_t_prim (new MatrixXf(4, 1));
488  preds_t_prim->col(0) = predictions_t_prim_batch->col(i);
489  // Get best value for the NEXT state - position from (t+1) state.
490  float max_q_st_prim_at_prim = computeBestValueForGivenStateAndPredictions(ge_ptr->s_t_prim, preds_t_prim->data());
491  // If the best value of the next state is finite,
492  // update the target for the given action - Deep Q-learning!
493  if (std::isfinite(max_q_st_prim_at_prim))
494  (*targets_t_batch)((size_t)ge_ptr->a_t.getType(), i) = step_reward + discount_rate*max_q_st_prim_at_prim;
495  }//: else
496 
497  }//: for
498 
499  LOG(LDEBUG) <<"Inputs batch:\n" << inputs_t_batch->transpose();
500  LOG(LDEBUG) <<"Targets batch:\n" << targets_t_batch->transpose();
501 
502  // Perform the Deep-Q-learning.
503  //LOG(LDEBUG) << "Network responses before training:" << std::endl << streamNetworkResponseTable();
504 
505  // Train network with rewards.
506  float loss = neural_net.train (inputs_t_batch, targets_t_batch, learning_rate, nn_weight_decay);
507  LOG(LDEBUG) << "Training loss:" << loss;
508 
509  //LOG(LDEBUG) << "Network responses after training:" << std::endl << streamNetworkResponseTable();
510 
511  // Finish the replay: move the player to REAL, CURRENT POSITION.
512  env.moveAgentToPosition(player_pos_t_prim);
513  }//: if enough experiences
514  else
515  LOG(LWARNING) << "Not enough samples in the experience replay memory!";
516 
517  LOG(LNOTICE) << "Network responses: \n" << streamNetworkResponseTable();
518  LOG(LNOTICE) << "Observation: \n" << env.observationToString();
519  LOG(LNOTICE) << "Environment: \n" << env.environmentToString();
520  // Do not forget to get the current observation!
521  env.getObservation();
522
523  // Check whether we reached maximum number of iterations.
524  if ((step_limit>0) && (iteration > (size_t)step_limit))
525  return false;
526 
527  return true;
528 }
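
The target computed for each replayed sample above follows the Deep Q-learning rule with two special cases: a penalty of 3 * step_reward when the move failed (the position did not change), the environment's reward when s(t+1) is terminal, and otherwise step_reward + discount_rate * max_a' Q(s(t+1), a'). A per-sample sketch, with plain arguments standing in for the environment queries and the network predictions:

#include <cmath>

// Sketch only: the Deep Q-learning target for the executed action of one replayed sample.
// 'max_q_next' stands for computeBestValueForGivenStateAndPredictions(s_t_prim, ...).
float qLearningTarget(bool move_failed, bool next_state_terminal, float terminal_reward,
                      float step_reward, float discount_rate, float max_q_next,
                      float current_target) {
    if (move_failed)
        return 3.0f * step_reward;                        // blocked move: learn that as well
    if (next_state_terminal)
        return terminal_reward;                           // terminal state: learn its reward
    if (std::isfinite(max_q_next))
        return step_reward + discount_rate * max_q_next;  // r + gamma * max_a' Q(s', a')
    return current_target;                                // no allowed next action: keep the prediction
}
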
529 
530 } /* namespace application */
531 } /* namespace mic */
mic::types::TensorXfPtr getObservation()
Definition: MNISTDigit.cpp:208
mic::utils::DataCollectorPtr< std::string, float > collector_ptr
Data collector.
virtual bool moveAgentToPosition(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:253
WindowMNISTDigit * wmd_observation
Window displaying the observation.
virtual bool isStateTerminal(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:291
virtual std::string observationToString()
Definition: MNISTDigit.cpp:164
WindowCollectorChart< float > * w_chart
Window for displaying statistics.
virtual void add(std::shared_ptr< mic::types::SpatialExperience > input_, std::shared_ptr< mic::types::MatrixXf > target_)
Application of Partially Observable Deep Q-learning with Experience Replay to the MNIST digits problem.
std::shared_ptr< mic::types::SpatialExperience > SpatialExperiencePtr
Shared pointer to spatial experience object.
virtual void initializeEnvironment()
Definition: MNISTDigit.cpp:85
mic::configuration::Property< std::string > mlnn_filename
Property: name of the file to which the neural network will be serialized (or deserialized from)...
mic::configuration::Property< double > epsilon
virtual mic::types::MatrixXfPtr encodeObservation()
Definition: MNISTDigit.cpp:186
Structure storing a spatial experience - a triplet of position in time t, executed action and position in time t+1.
WindowMNISTDigit * wmd_environment
Window displaying the whole environment.
mic::configuration::Property< float > discount_rate
std::shared_ptr< std::vector< mic::types::Position2D > > saccadic_path
Saccadic path - a sequence of consecutive agent positions.
virtual size_t getObservationWidth()
Definition: Environment.hpp:87
virtual bool isActionAllowed(long x_, long y_, size_t action_)
Definition: Environment.cpp:70
MNISTDigitDLRERPOMDP(std::string node_name_="application")
mic::configuration::Property< bool > mlnn_save
Property: flag denoting whether the nn should be saved to a file (after every episode end).
mic::configuration::Property< std::string > statistics_filename
Property: name of the file to which the statistics will be exported.
virtual size_t getEnvironmentWidth()
Definition: Environment.hpp:69
virtual size_t getObservationSize()
Definition: Environment.hpp:99
size_t batch_size
Size of the batch in experience replay - set to the size of the maze (width*height).
virtual mic::types::Position2D getAgentPosition()
Definition: MNISTDigit.cpp:238
float computeBestValueForGivenStateAndPredictions(mic::types::Position2D player_position_, float *predictions_)
virtual std::string environmentToString()
Definition: MNISTDigit.cpp:160
unsigned int optimalPathLength()
Definition: MNISTDigit.hpp:159
mic::configuration::Property< float > learning_rate
bool moveAgent(mic::types::Action2DInterface ac_)
Definition: Environment.cpp:48
mic::configuration::Property< float > step_reward
virtual float getStateReward(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:271
BackpropagationNeuralNetwork< float > neural_net
Multi-layer neural network used for approximation of the Q-state rewards.
virtual bool isStateAllowed(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:280
virtual size_t getEnvironmentHeight()
Definition: Environment.hpp:75
mic::types::Batch< mic::types::SpatialExperience, mic::types::MatrixXf > SpatialExperienceBatch
Spatial experience replay batch.
mic::types::TensorXfPtr & getEnvironment()
Definition: Environment.hpp:63
void RegisterApplication(void)
Registers application.
virtual void initialize(int argc, char *argv[])
mic::types::NESWAction selectBestActionForGivenState(mic::types::Position2D player_position_)
mic::types::Sample< mic::types::SpatialExperience, mic::types::MatrixXf > SpatialExperienceSample
Spatial experience replay sample.
mic::configuration::Property< int > step_limit
mic::configuration::Property< bool > mlnn_load
Property: flag denoting whether the nn should be loaded from a file (at the initialization of the task).
mic::environments::MNISTDigit env
The MNIST digit environment.
mic::types::MatrixXfPtr getPredictedRewardsForGivenState(mic::types::Position2D player_position_)
virtual size_t getObservationHeight()
Definition: Environment.hpp:93