26 namespace environments {
32 registerProperty(
type);
39 type(
"type", gw_.type)
42 registerProperty(
type);
111 LOG(LINFO) <<
"Generating exemplary gridworld";
141 LOG(LINFO) <<
"Generating classic cliff gridworld";
159 for(
size_t x=0; x<
width; x++)
167 LOG(LINFO) <<
"Generating classic discount gridworld";
180 (*environment_grid).zeros();
187 for(
size_t x=0; x<
width; x++)
202 LOG(LINFO) <<
"Generating classic bridge gridworld";
220 for(
size_t x=1; x<width-1; x++) {
238 LOG(LINFO) <<
"Generating classic book environment_grid!!";
267 LOG(LINFO) <<
"Generating classic maze gridworld";
300 LOG(LINFO) <<
"Generating environment_grid from Deep Q-Learning example";
331 LOG(LINFO) <<
"Generating a slightly modified grid from Deep Q-Learning example";
363 LOG(LINFO) <<
"Generating the 2x2 debug grid";
394 LOG(LINFO) <<
"Generating the 3x3 debug grid";
428 LOG(LINFO) <<
"Generating simple " << width <<
"x" <<
height<<
" random grid";
435 mic::types::Position2D agent(0, width-1, 0, height-1);
440 std::random_device rd;
441 std::mt19937_64 rng_mt19937_64(rd());
446 mic::types::Position2D wall(0, width-1, 0, height-1);
460 mic::types::Position2D pit(0, width-1, 0, height-1);
478 mic::types::Position2D goal(0, width-1, 0, height-1);
489 bool reachable =
false;
490 for (
size_t a=0; a<4; a++){
491 mic::types::NESWAction action(a);
492 mic::types::Position2D way_to_goal = goal + action;
523 visited_(y_,x_) =
true;
541 LOG(LINFO) <<
"Generating hard " << width <<
"x" << height<<
" random grid";
548 mic::types::Position2D agent(0, width-1, 0, height-1);
555 mic::types::Position2D goal(0, width-1, 0, height-1);
567 std::random_device rd;
568 std::mt19937_64 rng_mt19937_64(rd());
572 size_t max_obstacles = sqrt(width*height) - 2;
573 std::uniform_int_distribution<size_t> obstacle_dist(0, max_obstacles);
576 size_t number_of_walls = obstacle_dist(rng_mt19937_64);
579 mic::types::Matrix<bool> visited (height, width);
582 for (
size_t i=0; i<number_of_walls; i++) {
585 mic::types::Position2D wall(0, width-1, 0, height-1);
613 size_t number_of_pits = obstacle_dist(rng_mt19937_64);
616 for (
size_t i=0; i<number_of_pits; i++) {
619 mic::types::Position2D pit(0, width-1, 0, height-1);
626 if ((*
environment_grid)({(size_t)pit.x, (
size_t)pit.y, (size_t)GridworldChannels::Pits}) != 0)
632 (*environment_grid)({(size_t)pit.x, (
size_t)pit.y, (size_t)GridworldChannels::Pits}) = -10;
638 (*environment_grid)({(size_t)pit.x, (
size_t)pit.y, (size_t)GridworldChannels::Pits}) = 0;
655 for (
size_t x=0; x<grid_->dim(0); x++)
659 for (
size_t y=0; y<grid_->dim(1); y++){
661 for (
size_t x=0; x<grid_->dim(0); x++) {
669 }
else if ((*grid_)({x,y, (size_t)GridworldChannels::Pits}) < 0) {
683 for (
size_t x=0; x<grid_->dim(0); x++)
707 mic::types::MatrixXfPtr encoded_grid (
new mic::types::MatrixXf(*
environment_grid));
716 LOG(LDEBUG) <<
"encodeObservation()";
726 mic::types::MatrixXfPtr encoded_obs (
new mic::types::MatrixXf(*obs));
738 LOG(LDEBUG) <<
"getObservation()";
746 for (
long oy=0, ey=(p.y-delta); oy< (
long)
roi_size; oy++, ey++){
747 for (
long ox=0, ex=(p.x-delta); ox< (
long)
roi_size; ox++, ex++) {
749 if ((ex < 0) || (ex >= (
long)width) || (ey < 0) || (ey >= (
long)height)){
756 (*observation_grid)({(size_t)ox,(
size_t)oy, (size_t)GridworldChannels::Pits}) = (*environment_grid)({(size_t)ex,(
size_t)ey, (size_t)GridworldChannels::Pits});
770 mic::types::MatrixXfPtr encoded_grid (
new mic::types::MatrixXf(height, width));
771 encoded_grid->setZero();
773 for (
size_t y=0; y<
height; y++){
774 for (
size_t x=0; x<
width; x++) {
778 (*encoded_grid)(y,x) = 1;
783 encoded_grid->resize(height*width, 1);
791 mic::types::Position2D position;
792 for (
size_t y=0; y<
height; y++){
793 for (
size_t x=0; x<
width; x++) {
806 LOG(LDEBUG) <<
"New agent position = " << pos_;
825 if ((*
environment_grid)({(size_t)pos_.x, (
size_t)pos_.y, (
size_t)GridworldChannels::Pits}) != 0)
826 return (*
environment_grid)({(size_t)pos_.x, (
size_t)pos_.y, (size_t)GridworldChannels::Pits});
835 if ((pos_.x < 0) || (pos_.x >= (
long)width))
838 if ((pos_.y < 0) || (pos_.y >= (
long)height))
850 if ((pos_.x < 0) || (pos_.x >= (
long)width))
853 if ((pos_.y < 0) || (pos_.y >= (
long)height))
857 if ((*
environment_grid)({(size_t)pos_.x, (
size_t)pos_.y, (
size_t)GridworldChannels::Pits}) != 0)
void initModifiedDQLGrid()
virtual bool moveAgentToPosition(mic::types::Position2D pos_)
void initClassicCliffGrid()
virtual float getStateReward(mic::types::Position2D pos_)
Class emulating the gridworld environment.
mic::types::TensorXfPtr getObservation()
Abstract class representing an environment.
std::string gridToString(mic::types::TensorXfPtr grid_)
virtual mic::types::MatrixXfPtr encodeObservation()
size_t channels
Number of channels.
virtual mic::types::MatrixXfPtr encodeEnvironment()
virtual mic::types::Position2D getAgentPosition()
bool isGridTraversible(long x_, long y_, mic::types::Matrix< bool > &visited_)
void initSimpleRandomGrid()
void initHardRandomGrid()
mic::configuration::Property< size_t > width
Property: width of the environment.
virtual bool isStateTerminal(mic::types::Position2D pos_)
bool pomdp_flag
Flag related to.
Channel storing the agent position.
void initExemplaryDQLGrid()
virtual mic::types::MatrixXfPtr encodeAgentGrid()
Encode the current state of the reduced grid (only the agent position) as a matrix of size [1...
Gridworld(std::string node_name_="gridworld")
mic::environments::Gridworld & operator=(const mic::environments::Gridworld &gw_)
mic::configuration::Property< size_t > roi_size
Property: size of the ROI (region of interest).
virtual std::string environmentToString()
mic::types::TensorXfPtr observation_grid
mic::types::TensorXfPtr environment_grid
Tensor storing the environment.
virtual bool isStateAllowed(mic::types::Position2D pos_)
virtual void initializePropertyDependentVariables()
mic::types::Position2D initial_position
Property: initial position of the agent.
mic::configuration::Property< size_t > height
Property: height of the environment.
virtual void initializeEnvironment()
mic::configuration::Property< short > type
virtual std::string observationToString()