MachineIntelligenceCore:ReinforcementLearning
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator
Gridworld.cpp
Go to the documentation of this file.
1 
23 #include <types/Gridworld.hpp>
24 
25 namespace mic {
26 namespace environments {
27 
28 Gridworld::Gridworld(std::string node_name_) : Environment(node_name_),
29  type("type", 0)
30 {
31  // Register properties - so their values can be overridden (read from the configuration file).
32  registerProperty(type);
33 
35 
36 }
37 
38 Gridworld::Gridworld (const mic::environments::Gridworld & gw_) : Environment(gw_.getNodeName()+"_copy"),
39  type("type", gw_.type)
40 {
41  // Register properties - so their values can be overridden (read from the configuration file).
42  registerProperty(type);
43  // Not used, but still let's copy it.
44  type = gw_.type;
45  // Copy size.
46  width = gw_.width;
47  height = gw_.height;
48  channels = gw_.channels;
49  // Copy the environment.
53 }
54 
55 
56 
58  // TODO Auto-generated destructor stub
59 }
60 
62  // Not used, but still let's copy it.
63  type = gw_.type;
64  // Copy size.
65  width = gw_.width;
66  height = gw_.height;
67  channels = gw_.channels;
68  // Copy the environment.
72  // Return pointer to updated instance.
73  return *this;
74 }
75 
76 
78  // Empty - everything will be initialized in environment initialization.
79 }
80 
82  // Generate adequate gridworld.
83  switch(type) {
84  case 0 : initExemplaryGrid(); break;
85  case 1 : initClassicCliffGrid(); break;
86  case 2 : initDiscountGrid(); break;
87  case 3 : initBridgeGrid(); break;
88  case 4 : initBookGrid(); break;
89  case 5 : initMazeGrid(); break;
90  case 6 : initExemplaryDQLGrid(); break;
91  case 7 : initModifiedDQLGrid(); break;
92  case 8 : initDebug2x2Grid(); break;
93  case 9 : initDebug3x3Grid(); break;
94  case -2: initHardRandomGrid(); break;
95  case -1:
96  default: initSimpleRandomGrid();
97  }//: switch
98 
99  // Check whether it is a POMDP or not.
100  if (roi_size >0) {
101  pomdp_flag = true;
103  } else {
104  observation_grid->resize({width, height, channels});
105  }//: else
106 
107 }
108 
109 
111  LOG(LINFO) << "Generating exemplary gridworld";
112  // [[' ',' ',' ',' '],
113  // ['S',-10,' ',' '],
114  // [' ','','#',' '],
115  // [' ',' ',' ',10]]
116 
117  // Overwrite dimensions.
118  width = 4;
119  height = 4;
120 
121  // Set gridworld size.
122  environment_grid->resize({width, height, channels});
123  environment_grid->zeros();
124 
125  // Place the agent.
126  initial_position.set(0,1);
128 
129  // Place wall(s).
130  (*environment_grid)({2,2, (size_t)GridworldChannels::Walls}) = 1;
131 
132  // Place pit(s).
133  (*environment_grid)({1,1, (size_t)GridworldChannels::Pits}) = -10;
134 
135  // Place goal(s).
136  (*environment_grid)({3,3, (size_t)GridworldChannels::Goals}) = 10;
137 }
138 
139 
141  LOG(LINFO) << "Generating classic cliff gridworld";
142  // [[' ',' ',' ',' ',' '],
143  // ['S',' ',' ',' ',10],
144  // [-100,-100, -100, -100, -100]]
145 
146  // Overwrite dimensions.
147  width = 5;
148  height = 3;
149 
150  // Set gridworld size.
151  environment_grid->resize({width, height, channels});
152  environment_grid->zeros();
153 
154  // Place the agent.
155  initial_position.set(0,1);
157 
158  // Place pit(s).
159  for(size_t x=0; x<width; x++)
160  (*environment_grid)({x,2, (size_t)GridworldChannels::Pits}) = -100;
161 
162  // Place goal(s).
163  (*environment_grid)({4,1, (size_t)GridworldChannels::Goals}) = 10;
164 }
165 
167  LOG(LINFO) << "Generating classic discount gridworld";
168  // [[' ',' ',' ',' ',' '],
169  // [' ','#',' ',' ',' '],
170  // [' ','#', 1,'#', 10],
171  // ['S',' ',' ',' ',' '],
172  // [-10,-10, -10, -10, -10]]
173 
174  // Overwrite dimensions.
175  width = 5;
176  height = 5;
177 
178  // Set gridworld size.
179  (*environment_grid).resize({width, height, channels});
180  (*environment_grid).zeros();
181 
182  // Place the agent.
183  initial_position.set(0,3);
185 
186  // Place pits.
187  for(size_t x=0; x<width; x++)
188  (*environment_grid)({x,4, (size_t)GridworldChannels::Pits}) = -10;
189 
190  // Place wall(s).
191  (*environment_grid)({1,1, (size_t)GridworldChannels::Walls}) = 1;
192  (*environment_grid)({1,2, (size_t)GridworldChannels::Walls}) = 1;
193  (*environment_grid)({3,2, (size_t)GridworldChannels::Walls}) = 1;
194 
195  // Place goal(s).
196  (*environment_grid)({2,2, (size_t)GridworldChannels::Goals}) = 1;
197  (*environment_grid)({4,2, (size_t)GridworldChannels::Goals}) = 10;
198 }
199 
200 
202  LOG(LINFO) << "Generating classic bridge gridworld";
203  // [[ '#',-100, -100, -100, -100, -100, '#'],
204  // [ 1, 'S', ' ', ' ', ' ', ' ', 10],
205  // [ '#',-100, -100, -100, -100, -100, '#']]
206 
207  // Overwrite dimensions.
208  width = 7;
209  height = 3;
210 
211  // Set environment_grid size.
212  environment_grid->resize({width, height, channels});
213  environment_grid->zeros();
214 
215  // Place the agent.
216  initial_position.set(1,1);
218 
219  // Place pits.
220  for(size_t x=1; x<width-1; x++) {
221  (*environment_grid)({x,0, (size_t)GridworldChannels::Pits}) = -100;
222  (*environment_grid)({x,2, (size_t)GridworldChannels::Pits}) = -100;
223  }//: for
224 
225  // Place wall(s).
226  (*environment_grid)({0,0, (size_t)GridworldChannels::Walls}) = 1;
227  (*environment_grid)({0,2, (size_t)GridworldChannels::Walls}) = 1;
228  (*environment_grid)({6,0, (size_t)GridworldChannels::Walls}) = 1;
229  (*environment_grid)({6,2, (size_t)GridworldChannels::Walls}) = 1;
230 
231  // Place goal(s).
232  (*environment_grid)({0,1, (size_t)GridworldChannels::Goals}) = 1;
233  (*environment_grid)({6,1, (size_t)GridworldChannels::Goals}) = 10;
234 }
235 
236 
238  LOG(LINFO) << "Generating classic book environment_grid!!";
239  // [[' ',' ',' ',+1],
240  // [' ','#',' ',-1],
241  // ['S',' ',' ',' ']]
242 
243  // Overwrite dimensions.
244  width = 4;
245  height = 3;
246 
247  // Set environment_grid size.
248  environment_grid->resize({width, height, channels});
249  environment_grid->zeros();
250 
251  // Place the agent.
252  initial_position.set(0,2);
254 
255  // Place wall(s).
256  (*environment_grid)({1,1, (size_t)GridworldChannels::Walls}) = 1;
257 
258  // Place pit(s).
259  (*environment_grid)({3,1, (size_t)GridworldChannels::Pits}) = -1;
260 
261  // Place goal(s).
262  (*environment_grid)({3,0, (size_t)GridworldChannels::Goals}) = 1;
263 }
264 
265 
267  LOG(LINFO) << "Generating classic maze gridworld";
268  // [[' ',' ',' ',+1],
269  // ['#','#',' ','#'],
270  // [' ','#',' ',' '],
271  // [' ','#','#',' '],
272  // ['S',' ',' ',' ']]
273 
274  // Overwrite dimensions.
275  width = 4;
276  height = 5;
277 
278  // Set environment_grid size.
279  environment_grid->resize({width, height, channels});
280  environment_grid->zeros();
281 
282  // Place the agent.
283  initial_position.set(0,4);
285 
286  // Place wall(s).
287  (*environment_grid)({0,1, (size_t)GridworldChannels::Walls}) = 1;
288  (*environment_grid)({1,1, (size_t)GridworldChannels::Walls}) = 1;
289  (*environment_grid)({1,2, (size_t)GridworldChannels::Walls}) = 1;
290  (*environment_grid)({1,3, (size_t)GridworldChannels::Walls}) = 1;
291  (*environment_grid)({2,3, (size_t)GridworldChannels::Walls}) = 1;
292  (*environment_grid)({3,1, (size_t)GridworldChannels::Walls}) = 1;
293 
294  // Place goal(s).
295  (*environment_grid)({3,0, (size_t)GridworldChannels::Goals}) = 1;
296 }
297 
298 
300  LOG(LINFO) << "Generating environment_grid from Deep Q-Learning example";
301  /*
302  * [[' ',' ',' ',' '],
303  * [' ',' ',+10,' '],
304  * [' ','#',-10,' '],
305  * ['S',' ',' ',' ']]
306  */
307 
308  // Overwrite dimensions.
309  width = 4;
310  height = 4;
311 
312  // Set environment_grid size.
313  environment_grid->resize({width, height, channels});
314  environment_grid->zeros();
315 
316  // Place the agent.
317  initial_position.set(0,3);
319 
320  // Place wall(s).
321  (*environment_grid)({1,2, (size_t)GridworldChannels::Walls}) = 1;
322 
323  // Place pit(s).
324  (*environment_grid)({2,2, (size_t)GridworldChannels::Pits}) = -10;
325 
326  // Place goal(s).
327  (*environment_grid)({2,1, (size_t)GridworldChannels::Goals}) = 10;
328 }
329 
331  LOG(LINFO) << "Generating a slightly modified grid from Deep Q-Learning example";
332  /*
333  * [[' ',' ',' ',' '],
334  * [' ','#',+10,' '],
335  * [' ',' ',-10,' '],
336  * ['S',' ',' ',' ']]
337  */
338 
339  // Overwrite dimensions.
340  width = 4;
341  height = 4;
342 
343  // Set environment_grid size.
344  environment_grid->resize({width, height, channels});
345  environment_grid->zeros();
346 
347  // Place the agent.
348  initial_position.set(0,3);
350 
351  // Place wall(s).
352  (*environment_grid)({1,1, (size_t)GridworldChannels::Walls}) = 1;
353 
354  // Place pit(s).
355  (*environment_grid)({2,2, (size_t)GridworldChannels::Pits}) = -10;
356 
357  // Place goal(s).
358  (*environment_grid)({2,1, (size_t)GridworldChannels::Goals}) = 10;
359 }
360 
361 
363  LOG(LINFO) << "Generating the 2x2 debug grid";
364  /*
365  * [['S',-10],
366  * [+10,' ']]
367  */
368 
369  // Overwrite dimensions.
370  width = 2;
371  height = 2;
372 
373  // Set environment_grid size.
374  environment_grid->resize({width, height, channels});
375  environment_grid->zeros();
376 
377  // Place the agent.
378  initial_position.set(0,0);
380 
381  // Place pit(s).
382  (*environment_grid)({1,0, (size_t)GridworldChannels::Pits}) = -10;
383 
384  // Place goal(s).
385  (*environment_grid)({0,1, (size_t)GridworldChannels::Goals}) = 10;
386 }
387 
388 
394  LOG(LINFO) << "Generating the 3x3 debug grid";
395  /*
396  * [[' ',-10,' '],
397  * [-10,'S',-10],
398  * [' ',+10,' ']]
399  */
400 
401  // Overwrite the dimensions.
402  width = 3;
403  height = 3;
404 
405  // Set environment_grid size.
406  environment_grid->resize({width, height, channels});
407  environment_grid->zeros();
408 
409  // Place the agent.
410  initial_position.set(1,1);
412 
413  // Place wall(s).
414  (*environment_grid)({1,2, (size_t)GridworldChannels::Walls}) = 1;
415 
416  // Place pit(s).
417  (*environment_grid)({0,1, (size_t)GridworldChannels::Pits}) = -10;
418  (*environment_grid)({1,0, (size_t)GridworldChannels::Pits}) = -10;
419  (*environment_grid)({2,1, (size_t)GridworldChannels::Pits}) = -10;
420 
421  // Place goal(s).
422  (*environment_grid)({1,2, (size_t)GridworldChannels::Goals}) = 10;
423 
424 }
425 
426 
428  LOG(LINFO) << "Generating simple " << width << "x" << height<< " random grid";
429 
430  // Set environment_grid size.
431  environment_grid->resize({width, height, channels});
432  environment_grid->zeros();
433 
434  // Place the agent.
435  mic::types::Position2D agent(0, width-1, 0, height-1);
436  initial_position = agent;
438 
439  // Initialize random device and generator.
440  std::random_device rd;
441  std::mt19937_64 rng_mt19937_64(rd());
442 
443  // Place wall.
444  while (1){
445  // Random position.
446  mic::types::Position2D wall(0, width-1, 0, height-1);
447 
448  // Validate pose.
449  if ((*environment_grid)({(size_t)wall.x, (size_t)wall.y, (size_t)GridworldChannels::Agent}) != 0)
450  continue;
451 
452  // Add wall...
453  (*environment_grid)({(size_t)wall.x, (size_t)wall.y, (size_t)GridworldChannels::Walls}) = 1;
454  break;
455  }
456 
457  // Place pit.
458  while(1){
459  // Random position.
460  mic::types::Position2D pit(0, width-1, 0, height-1);
461 
462  // Validate pose.
463  if ((*environment_grid)({(size_t)pit.x, (size_t)pit.y, (size_t)GridworldChannels::Agent}) != 0)
464  continue;
465  if ((*environment_grid)({(size_t)pit.x, (size_t)pit.y, (size_t)GridworldChannels::Walls}) != 0)
466  continue;
467 
468  // Add pit...
469  (*environment_grid)({(size_t)pit.x, (size_t)pit.y, (size_t)GridworldChannels::Pits}) = -10;
470 
471  break;
472  }//: while
473 
474 
475  // Place goal.
476  while(1) {
477  // Random position.
478  mic::types::Position2D goal(0, width-1, 0, height-1);
479 
480  // Validate pose.
481  if ((*environment_grid)({(size_t)goal.x, (size_t)goal.y, (size_t)GridworldChannels::Agent}) != 0)
482  continue;
483  if ((*environment_grid)({(size_t)goal.x, (size_t)goal.y, (size_t)GridworldChannels::Walls}) != 0)
484  continue;
485  if ((*environment_grid)({(size_t)goal.x, (size_t)goal.y, (size_t)GridworldChannels::Pits}) != 0)
486  continue;
487 
488  // ... but additionally check the goal surroundings - there must be at least one way out, and not going through the pit!
489  bool reachable = false;
490  for (size_t a=0; a<4; a++){
491  mic::types::NESWAction action(a);
492  mic::types::Position2D way_to_goal = goal + action;
493  if ((isStateAllowed(way_to_goal)) &&
494  ((*environment_grid)({(size_t)way_to_goal.x, (size_t)way_to_goal.y, (size_t)GridworldChannels::Pits}) == 0)) {
495  reachable = true;
496  break;
497  }//: if
498  }//: for
499  if (!reachable)
500  continue;
501 
502  // Ok, add the goal.
503  (*environment_grid)({(size_t)goal.x, (size_t)goal.y, (size_t)GridworldChannels::Goals}) = 10;
504  break;
505  }//: while
506 
507 }
508 
509 bool Gridworld::isGridTraversible(long x_, long y_, mic::types::Matrix<bool> & visited_) {
510  // If not allowed...
511  if (!isStateAllowed(x_, y_))
512  return false;
513  // .. or is a pit...
514  if ((*environment_grid)({(size_t)x_, (size_t)y_, (size_t)GridworldChannels::Pits}) < 0)
515  return false;
516  // ... or wasa already visited.
517  if (visited_(y_,x_))
518  return false;
519  // Ok found the goal!
520  if ((*environment_grid)({(size_t)x_, (size_t)y_, (size_t)GridworldChannels::Goals}) > 0)
521  return true;
522  // Ok, new state.
523  visited_(y_,x_) = true;
524 
525  // Recursive check NESW.
526  if (isGridTraversible(x_, y_-1, visited_))
527  return true;
528  if (isGridTraversible(x_+1, y_, visited_))
529  return true;
530  if (isGridTraversible(x_, y_+1, visited_))
531  return true;
532  if (isGridTraversible(x_-1, y_, visited_))
533  return true;
534  // Sorry, no luck in her.
535  return false;
536 }
537 
538 
539 
541  LOG(LINFO) << "Generating hard " << width << "x" << height<< " random grid";
542 
543  // Set environment_grid size.
544  environment_grid->resize({width, height, channels});
545  environment_grid->zeros();
546 
547  // Place the agent.
548  mic::types::Position2D agent(0, width-1, 0, height-1);
549  initial_position = agent;
551 
552  // Place goal.
553  while(1) {
554  // Random position.
555  mic::types::Position2D goal(0, width-1, 0, height-1);
556 
557  // Validate pose.
558  if ((*environment_grid)({(size_t)goal.x, (size_t)goal.y, (size_t)GridworldChannels::Agent}) != 0)
559  continue;
560 
561  // Ok, add the goal.
562  (*environment_grid)({(size_t)goal.x, (size_t)goal.y, (size_t)GridworldChannels::Goals}) = 10;
563  break;
564  }//: while
565 
566  // Initialize random device and generator.
567  std::random_device rd;
568  std::mt19937_64 rng_mt19937_64(rd());
569 
570 
571  // Initialize uniform integer distribution.
572  size_t max_obstacles = sqrt(width*height) - 2;
573  std::uniform_int_distribution<size_t> obstacle_dist(0, max_obstacles);
574 
575  // Calculate number of walls.
576  size_t number_of_walls = obstacle_dist(rng_mt19937_64);
577 
578  // Matrix informing us thwther we already visited the state or not.
579  mic::types::Matrix<bool> visited (height, width);
580 
581  // Place wall(s).
582  for (size_t i=0; i<number_of_walls; i++) {
583  while (1){
584  // Random position.
585  mic::types::Position2D wall(0, width-1, 0, height-1);
586 
587  // Validate pose.
588  if ((*environment_grid)({(size_t)wall.x, (size_t)wall.y, (size_t)GridworldChannels::Agent}) != 0)
589  continue;
590  if ((*environment_grid)({(size_t)wall.x, (size_t)wall.y, (size_t)GridworldChannels::Goals}) != 0)
591  continue;
592  if ((*environment_grid)({(size_t)wall.x, (size_t)wall.y, (size_t)GridworldChannels::Walls}) != 0)
593  continue;
594 
595  // Add wall...
596  (*environment_grid)({(size_t)wall.x, (size_t)wall.y, (size_t)GridworldChannels::Walls}) = 1;
597 
598  // ... but additionally whether the path from agent to the goal is traversable!
599  visited.setZero();
600  if (!isGridTraversible(agent.x, agent.y, visited)) {
601  // Sorry, we must remove this wall...
602  (*environment_grid)({(size_t)wall.x, (size_t)wall.y, (size_t)GridworldChannels::Walls}) = 0;
603  // .. and try once again.
604  continue;
605  }//: if
606 
607  break;
608  }//: while
609  }//: for number of walls
610 
611 
612  // Calculate number of pits.
613  size_t number_of_pits = obstacle_dist(rng_mt19937_64);
614 
615  // Place pit(s).
616  for (size_t i=0; i<number_of_pits; i++) {
617  while(1){
618  // Random position.
619  mic::types::Position2D pit(0, width-1, 0, height-1);
620 
621  // Validate pose.
622  if ((*environment_grid)({(size_t)pit.x, (size_t)pit.y, (size_t)GridworldChannels::Agent}) != 0)
623  continue;
624  if ((*environment_grid)({(size_t)pit.x, (size_t)pit.y, (size_t)GridworldChannels::Goals}) != 0)
625  continue;
626  if ((*environment_grid)({(size_t)pit.x, (size_t)pit.y, (size_t)GridworldChannels::Pits}) != 0)
627  continue;
628  if ((*environment_grid)({(size_t)pit.x, (size_t)pit.y, (size_t)GridworldChannels::Walls}) != 0)
629  continue;
630 
631  // Add pit...
632  (*environment_grid)({(size_t)pit.x, (size_t)pit.y, (size_t)GridworldChannels::Pits}) = -10;
633 
634  // ... but additionally whether the path from agent to the goal is traversable!
635  visited.setZero();
636  if (!isGridTraversible(agent.x, agent.y, visited)) {
637  // Sorry, we must remove this pit...
638  (*environment_grid)({(size_t)pit.x, (size_t)pit.y, (size_t)GridworldChannels::Pits}) = 0;
639  // .. and try once again.
640  continue;
641  }//: if
642 
643  break;
644  }//: while
645  }//: for number of walls
646 
647 
648 }
649 
650 
651 std::string Gridworld::gridToString(mic::types::TensorXfPtr grid_) {
652  std::string s;
653  // Add line.
654  s+= "+";
655  for (size_t x=0; x<grid_->dim(0); x++)
656  s+="---";
657  s+= "+\n";
658 
659  for (size_t y=0; y<grid_->dim(1); y++){
660  s += "|";
661  for (size_t x=0; x<grid_->dim(0); x++) {
662  // Check object occupancy.
663  if ((*grid_)({x,y, (size_t)GridworldChannels::Agent}) != 0) {
664  // Display agent.
665  s += "<A>";
666  } else if ((*grid_)({x,y, (size_t)GridworldChannels::Walls}) != 0) {
667  // Display wall.
668  s += " # ";
669  } else if ((*grid_)({x,y, (size_t)GridworldChannels::Pits}) < 0) {
670  // Display pit.
671  s += " - ";
672  } else if ((*grid_)({x,y, (size_t)GridworldChannels::Goals}) > 0) {
673  // Display goal.
674  s += " + ";
675  } else
676  s += " ";
677  }//: for x
678  s += "|\n";
679  }//: for y
680 
681  // Add line.
682  s+= "+";
683  for (size_t x=0; x<grid_->dim(0); x++)
684  s+="---";
685  s+= "+\n";
686  return s;
687 }
688 
691 }
692 
694  if (pomdp_flag) {
695  // Get observation.
696  mic::types::TensorXfPtr obs = getObservation();
697  return gridToString(obs);
698  }
699  else
701 }
702 
703 mic::types::MatrixXfPtr Gridworld::encodeEnvironment() {
704  // Temporarily reshape the environment_grid.
705  environment_grid->conservativeResize({1, width * height * channels});
706  // Create a matrix pointer and copy data from grid into the matrix.
707  mic::types::MatrixXfPtr encoded_grid (new mic::types::MatrixXf(*environment_grid));
708  // Back to the original shape.
709  environment_grid->resize({width, height, channels});
710 
711  // Return the matrix pointer.
712  return encoded_grid;
713 }
714 
715 mic::types::MatrixXfPtr Gridworld::encodeObservation() {
716  LOG(LDEBUG) << "encodeObservation()";
717  if (pomdp_flag) {
718  mic::types::Position2D p = getAgentPosition();
719  LOG(LDEBUG) << p;
720 
721  // Get observation.
722  mic::types::TensorXfPtr obs = getObservation();
723  // Temporarily reshape the observation grid.
724  obs->conservativeResize({1, roi_size * roi_size * channels});
725  // Encode the observation.
726  mic::types::MatrixXfPtr encoded_obs (new mic::types::MatrixXf(*obs));
727  // Back to the original shape.
728  obs->conservativeResize({roi_size, roi_size, channels});
729 
730  return encoded_obs;
731  }
732  else
733  return encodeEnvironment();
734 }
735 
736 
737 mic::types::TensorXfPtr Gridworld::getObservation() {
738  LOG(LDEBUG) << "getObservation()";
739  // Reset.
740  observation_grid->zeros();
741 
742  size_t delta = (roi_size-1)/2;
743  mic::types::Position2D p = getAgentPosition();
744 
745  // Copy data.
746  for (long oy=0, ey=(p.y-delta); oy< (long)roi_size; oy++, ey++){
747  for (long ox=0, ex=(p.x-delta); ox< (long)roi_size; ox++, ex++) {
748  // Check grid boundaries.
749  if ((ex < 0) || (ex >= (long)width) || (ey < 0) || (ey >= (long)height)){
750  // Place the wall only
751  (*observation_grid)({(size_t)ox, (size_t)oy, (size_t)GridworldChannels::Walls}) = 1;
752  continue;
753  }//: if
754  // Else : copy data for all channels.
755  (*observation_grid)({(size_t)ox,(size_t)oy, (size_t)GridworldChannels::Goals}) = (*environment_grid)({(size_t)ex,(size_t)ey, (size_t)GridworldChannels::Goals});
756  (*observation_grid)({(size_t)ox,(size_t)oy, (size_t)GridworldChannels::Pits}) = (*environment_grid)({(size_t)ex,(size_t)ey, (size_t)GridworldChannels::Pits});
757  (*observation_grid)({(size_t)ox,(size_t)oy, (size_t)GridworldChannels::Walls}) = (*environment_grid)({(size_t)ex,(size_t)ey, (size_t)GridworldChannels::Walls});
758  (*observation_grid)({(size_t)ox,(size_t)oy, (size_t)GridworldChannels::Agent}) = (*environment_grid)({(size_t)ex,(size_t)ey, (size_t)GridworldChannels::Agent});
759  }//: for x
760  }//: for y
761 
762  //LOG(LDEBUG) << std::endl << gridToString(observation_grid);
763 
764  return observation_grid;
765 }
766 
767 
768 mic::types::MatrixXfPtr Gridworld::encodeAgentGrid() {
769  // DEBUG - copy only agent pose data, avoid goals etc.
770  mic::types::MatrixXfPtr encoded_grid (new mic::types::MatrixXf(height, width));
771  encoded_grid->setZero();
772 
773  for (size_t y=0; y<height; y++){
774  for (size_t x=0; x<width; x++) {
775  // Check object occupancy.
776  if ((*environment_grid)({x,y, (size_t)GridworldChannels::Agent}) != 0) {
777  // Set one.
778  (*encoded_grid)(y,x) = 1;
779  break;
780  }
781  }//: for x
782  }//: for y
783  encoded_grid->resize(height*width, 1);
784 
785  // Return the matrix pointer.
786  return encoded_grid;
787 }
788 
789 
790 mic::types::Position2D Gridworld::getAgentPosition() {
791  mic::types::Position2D position;
792  for (size_t y=0; y<height; y++){
793  for (size_t x=0; x<width; x++) {
794  if ((*environment_grid)({x,y, (size_t)GridworldChannels::Agent}) == 1) {
795  position.x = x;
796  position.y = y;
797  return position;
798  }// if
799  }//: for x
800  }//: for y
801  // Remove warnings...
802  return position;
803 }
804 
805 bool Gridworld::moveAgentToPosition(mic::types::Position2D pos_) {
806  LOG(LDEBUG) << "New agent position = " << pos_;
807 
808  // Check whether the state is allowed.
809  if (!isStateAllowed(pos_))
810  return false;
811 
812  // Clear old.
813  mic::types::Position2D old = getAgentPosition();
814  (*environment_grid)({(size_t)old.x, (size_t)old.y, (size_t)GridworldChannels::Agent}) = 0;
815  // Set new.
816  (*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)GridworldChannels::Agent}) = 1;
817 
818  return true;
819 }
820 
821 
822 
823 float Gridworld::getStateReward(mic::types::Position2D pos_) {
824  // Check reward - goal or pit.
825  if ((*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)GridworldChannels::Pits}) != 0)
826  return (*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)GridworldChannels::Pits});
827  else if ((*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)GridworldChannels::Goals}) != 0)
828  return (*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)GridworldChannels::Goals});
829  else
830  return 0;
831 }
832 
833 
834 bool Gridworld::isStateAllowed(mic::types::Position2D pos_) {
835  if ((pos_.x < 0) || (pos_.x >= (long)width))
836  return false;
837 
838  if ((pos_.y < 0) || (pos_.y >= (long)height))
839  return false;
840 
841  // Check walls!
842  if ((*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)GridworldChannels::Walls}) != 0)
843  return false;
844 
845  return true;
846 }
847 
848 
849 bool Gridworld::isStateTerminal(mic::types::Position2D pos_) {
850  if ((pos_.x < 0) || (pos_.x >= (long)width))
851  return false;
852 
853  if ((pos_.y < 0) || (pos_.y >= (long)height))
854  return false;
855 
856  // Check reward - goal or pit.
857  if ((*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)GridworldChannels::Pits}) != 0)
858  return true;
859  else if ((*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)GridworldChannels::Goals}) != 0)
860  return true;
861  else
862 
863  return false;
864 }
865 
866 
867 } /* namespace environments */
868 } /* namespace mic */
virtual bool moveAgentToPosition(mic::types::Position2D pos_)
Definition: Gridworld.cpp:805
virtual float getStateReward(mic::types::Position2D pos_)
Definition: Gridworld.cpp:823
Class emulating the gridworld environment.
Definition: Gridworld.hpp:50
mic::types::TensorXfPtr getObservation()
Definition: Gridworld.cpp:737
Abstract class representing an environment.
Definition: Environment.hpp:40
std::string gridToString(mic::types::TensorXfPtr grid_)
Definition: Gridworld.cpp:651
virtual mic::types::MatrixXfPtr encodeObservation()
Definition: Gridworld.cpp:715
size_t channels
Number of channels.
virtual mic::types::MatrixXfPtr encodeEnvironment()
Definition: Gridworld.cpp:703
virtual mic::types::Position2D getAgentPosition()
Definition: Gridworld.cpp:790
bool isGridTraversible(long x_, long y_, mic::types::Matrix< bool > &visited_)
Definition: Gridworld.cpp:509
mic::configuration::Property< size_t > width
Property: width of the environment.
virtual bool isStateTerminal(mic::types::Position2D pos_)
Definition: Gridworld.cpp:849
bool pomdp_flag
Flag indicating whether the environment is partially observable (POMDP); set when roi_size > 0.
Channel storing the agent position.
virtual mic::types::MatrixXfPtr encodeAgentGrid()
Encode the current state of the reduced grid (only the agent position) as a matrix of size [height*width x 1].
Definition: Gridworld.cpp:768
Gridworld(std::string node_name_="gridworld")
Definition: Gridworld.cpp:28
mic::environments::Gridworld & operator=(const mic::environments::Gridworld &gw_)
Definition: Gridworld.cpp:61
mic::configuration::Property< size_t > roi_size
Property: size of the ROI (region of interest).
virtual std::string environmentToString()
Definition: Gridworld.cpp:689
mic::types::TensorXfPtr observation_grid
mic::types::TensorXfPtr environment_grid
Tensor storing the environment.
virtual bool isStateAllowed(mic::types::Position2D pos_)
Definition: Gridworld.cpp:834
virtual void initializePropertyDependentVariables()
Definition: Gridworld.cpp:77
mic::types::Position2D initial_position
Initial position of the agent.
mic::configuration::Property< size_t > height
Property: height of the environment.
virtual void initializeEnvironment()
Definition: Gridworld.cpp:81
mic::configuration::Property< short > type
Definition: Gridworld.hpp:294
virtual std::string observationToString()
Definition: Gridworld.cpp:693