MachineIntelligenceCore:ReinforcementLearning
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator
GridworldValueIteration.cpp
Go to the documentation of this file.
1 
24 #include <limits>
25 #include <utils/RandomGenerator.hpp>
27 
28 namespace mic {
29 namespace application {
30 
35 void RegisterApplication (void) {
36  REGISTER_APPLICATION(mic::application::GridworldValueIteration);
37 }
38 
39 
40 GridworldValueIteration::GridworldValueIteration(std::string node_name_) : Application(node_name_),
41  step_reward("step_reward", 0.0),
42  discount_rate("discount_rate", 0.9),
43  move_noise("move_noise",0.2),
44  statistics_filename("statistics_filename","statistics_filename.csv")
45 
46  {
47  // Register properties - so their values can be overridden (read from the configuration file).
48  registerProperty(step_reward);
49  registerProperty(discount_rate);
50  registerProperty(move_noise);
51  registerProperty(statistics_filename);
52 
53  LOG(LINFO) << "Properties registered";
54 }
55 
56 
58 
59 }
60 
61 
62 void GridworldValueIteration::initialize(int argc, char* argv[]) {
63 
64 }
65 
67  // Initialize the gridworld.
69 
70  // Resize and reset the action-value table.
72  //state_value_table.zeros();
73  state_value_table.setValue( -std::numeric_limits<float>::infinity() );
74  running_delta = -std::numeric_limits<float>::infinity();
75 
76  LOG(LSTATUS) << std::endl << streamStateActionTable();
77 }
78 
79 
80 
82  std::ostringstream os;
83  for (size_t y=0; y<grid_env.getEnvironmentHeight(); y++){
84  os << "| ";
85  for (size_t x=0; x<grid_env.getEnvironmentWidth(); x++) {
86  if ( state_value_table(y,x) == -std::numeric_limits<float>::infinity())
87  os << "-INF | ";
88  else
89  os << state_value_table(y,x) << " | ";
90  }//: for x
91  os << std::endl;
92  }//: for y
93  return os.str();
94 
95 }
96 
97 
98 
99 float GridworldValueIteration::computeQValueFromValues(mic::types::Position2D pos_, mic::types::NESWAction ac_){
100  // Compute the Q-value of action in state from the value function stored table.
101  mic::types::Position2D new_pos = pos_ + ac_;
102  float q_value = (1-move_noise)*(step_reward + discount_rate * state_value_table((size_t)new_pos.y, (size_t)new_pos.x));
103  float probs_normalizer = (1-move_noise);
104 
105  // Consider also east and west actions as possible actions - due to move_noise.
106  if ((ac_.getType() == types::NESW::North) || (ac_.getType() == types::NESW::South)) {
107  if (grid_env.isActionAllowed(pos_, A_EAST)) {
108  mic::types::Position2D east_pos = pos_ + A_EAST;
109  if (state_value_table((size_t)east_pos.y, (size_t)east_pos.x) != -std::numeric_limits<float>::infinity()) {
110  q_value += (move_noise/2)*(step_reward + discount_rate * state_value_table( (size_t)east_pos.y, (size_t)east_pos.x));
111  probs_normalizer += (move_noise/2);
112  }//:if != -INF
113  }//: if
114  if (grid_env.isActionAllowed(pos_, A_WEST)) {
115  mic::types::Position2D west_pos = pos_ + A_WEST;
116  if (state_value_table((size_t)west_pos.y, (size_t)west_pos.x) != -std::numeric_limits<float>::infinity()) {
117  q_value += (move_noise/2)*(step_reward + discount_rate * state_value_table((size_t)west_pos.y, (size_t)west_pos.x));
118  probs_normalizer += (move_noise/2);
119  }//:if != -INF
120  }//: if
121  }//: if
122 
123  // Consider also north and south actions as possible actions - due to move_noise.
124  if ((ac_.getType() == types::NESW::East) || (ac_.getType() == types::NESW::West)) {
125  if (grid_env.isActionAllowed(pos_, A_NORTH)) {
126  mic::types::Position2D north_pos = pos_ + A_NORTH;
127  if (state_value_table((size_t)north_pos.y, (size_t)north_pos.x) != -std::numeric_limits<float>::infinity()) {
128  q_value += (move_noise/2)*(step_reward + discount_rate * state_value_table((size_t)north_pos.y, (size_t)north_pos.x));
129  probs_normalizer += (move_noise/2);
130  }//:if != -INF
131  }//: if
132  if (grid_env.isActionAllowed(pos_, A_SOUTH)) {
133  mic::types::Position2D south_pos = pos_ + A_SOUTH;
134  if (state_value_table((size_t)south_pos.y, (size_t)south_pos.x) != -std::numeric_limits<float>::infinity()) {
135  q_value += (move_noise/2)*(step_reward + discount_rate * state_value_table((size_t)south_pos.y, (size_t)south_pos.x));
136  probs_normalizer += (move_noise/2);
137  }//:if != -INF
138  }//: if
139  }//: if
140 
141  // Normalize the probabilities.
142  q_value /= probs_normalizer;
143 
144  return q_value;
145 }
146 
147 float GridworldValueIteration::computeBestValue(mic::types::Position2D pos_){
148  float best_value = -std::numeric_limits<float>::infinity();
149  // Check if the state is allowed.
150  if (!grid_env.isStateAllowed(pos_))
151  return best_value;
152 
153  // Create a list of possible actions.
154  std::vector<mic::types::NESWAction> actions;
155  actions.push_back(A_NORTH);
156  actions.push_back(A_EAST);
157  actions.push_back(A_SOUTH);
158  actions.push_back(A_WEST);
159 
160  // Check the actions one by one.
161  for(mic::types::NESWAction action : actions) {
162  if(grid_env.isActionAllowed(pos_, action)) {
163  float value = computeQValueFromValues(pos_, action);
164  if (value > best_value)
165  best_value = value;
166  }//if is allowed
167  }//: for
168 
169  return best_value;
170 }
171 
172 
174  LOG(LTRACE) << "Performing a single step (" << iteration << ")";
175 
176  // Perform the iterative policy iteration.
177  mic::types::MatrixXf new_state_value_table(grid_env.getEnvironmentHeight(), grid_env.getEnvironmentWidth());
178  new_state_value_table.setValue( -std::numeric_limits<float>::infinity() );
179 
180  for (size_t y=0; y<grid_env.getEnvironmentHeight(); y++){
181  for (size_t x=0; x<grid_env.getEnvironmentWidth(); x++) {
182  mic::types::Position2D pos(x,y);
183  if (grid_env.isStateTerminal(pos) ) {
184  // Set the state rewared.
185  new_state_value_table((size_t)pos.y, (size_t)pos.x) = grid_env.getStateReward(pos);
186  continue;
187  }//: if
188  // Else - compute the best value.
189  if (grid_env.isStateAllowed(pos) )
190  new_state_value_table((size_t)pos.y, (size_t)pos.x) = computeBestValue(pos);
191  }//: for x
192  }//: for y
193 
194  // Compute delta.
195  mic::types::MatrixXf delta_value;
196  float curr_delta = 0;
197  for (size_t i =0; i < (size_t) grid_env.getEnvironmentWidth() * grid_env.getEnvironmentHeight(); i++){
198  float tmp_delta = 0;
199  if (std::isfinite(new_state_value_table(i)))
200  tmp_delta += new_state_value_table(i);
201  if (std::isfinite(state_value_table(i)))
202  tmp_delta -= state_value_table(i);
203  curr_delta += std::abs(tmp_delta);
204  }//: for
205  running_delta = curr_delta;
206 
207  // Update state.
208  state_value_table = new_state_value_table;
209 
210  LOG(LSTATUS) << std::endl << grid_env.environmentToString();
211  LOG(LSTATUS) << std::endl << streamStateActionTable();
212  LOG(LINFO) << "Delta Value = " << running_delta;
213 
214  if (running_delta < 1e-05)
215  return false;
216 
217  return true;
218 }
219 
220 
221 
222 } /* namespace application */
223 } /* namespace mic */
mic::environments::Gridworld grid_env
The gridworld object.
virtual float getStateReward(mic::types::Position2D pos_)
Definition: Gridworld.cpp:823
GridworldValueIteration(std::string node_name_="application")
mic::types::MatrixXf state_value_table
Matrix storing values for all states (gridworld w * h). ROW MAJOR(!).
mic::configuration::Property< float > discount_rate
virtual bool isActionAllowed(long x_, long y_, size_t action_)
Definition: Environment.cpp:70
virtual bool isStateTerminal(mic::types::Position2D pos_)
Definition: Gridworld.cpp:849
mic::configuration::Property< float > move_noise
virtual size_t getEnvironmentWidth()
Definition: Environment.hpp:69
float computeQValueFromValues(mic::types::Position2D pos_, mic::types::NESWAction ac_)
mic::configuration::Property< std::string > statistics_filename
Property: name of the file to which the statistics will be exported.
Class responsible for solving the gridworld problem by applying the reinforcement learning value iteration method.
virtual std::string environmentToString()
Definition: Gridworld.cpp:689
virtual size_t getEnvironmentHeight()
Definition: Environment.hpp:75
virtual void initialize(int argc, char *argv[])
virtual bool isStateAllowed(mic::types::Position2D pos_)
Definition: Gridworld.cpp:834
void RegisterApplication(void)
Registers application.
float computeBestValue(mic::types::Position2D pos_)
virtual void initializeEnvironment()
Definition: Gridworld.cpp:81
mic::configuration::Property< float > step_reward
Declaration of the application class responsible for solving the gridworld problem with value iteration.