MachineIntelligenceCore:ReinforcementLearning
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator
MNISTDigit.cpp
Go to the documentation of this file.
1 
23 #include <types/MNISTDigit.hpp>
24 #include <utils/RandomGenerator.hpp>
25 
26 namespace mic {
27 namespace environments {
28 
29 MNISTDigit::MNISTDigit(std::string node_name_) : Environment(node_name_),
30  mnist_importer("mnist_importer"),
31  sample_number("sample_number", 0),
32  agent_x("agent_x",-1),
33  agent_y("agent_y",-1),
34  goal_x("goal_x",-1),
35  goal_y("goal_y",-1)
36 {
37  // Register properties - so their values can be overridden (read from the configuration file).
38  registerProperty(sample_number);
39  registerProperty(agent_x);
40  registerProperty(agent_y);
41  registerProperty(goal_x);
42  registerProperty(goal_y);
43 
45 
46 }
47 
49  // TODO Auto-generated destructor stub
50 }
51 
53  width = md_.width;
54  height = md_.height;
55  channels = md_.channels;
59 
60  return *this;
61 }
62 
63 
64 // Initialize environment_grid.
66  // Load dataset.
67  if (!mnist_importer.importData()) {
68  //return;
69  }
70 
71  // Set environment size.
72  width = height = 28;
74 
75  // Check whether it is a POMDP or not.
76  if (roi_size >0) {
77  pomdp_flag = true;
79  } else {
81  }//: else
82 
83 }
84 
86  // Reset the grid.
87  environment_grid->zeros();
88 
89  if (mnist_importer.size() > 0) {
90  mic::types::MNISTSample<> sample;
91 // LOG(LERROR) << "sample_number: " << sample_number;
92 
93  if (sample_number >= mnist_importer.size())
94  // Random select sample from dataset.
95  sample = mnist_importer.getRandomSample();
96  else
97  // Get given sample.
98  sample = mnist_importer.getSample(sample_number);
99 
100  for (size_t x=0; x<width; x++)
101  for (size_t y=0; y<height; y++)
102  (*environment_grid)({x,y,(size_t)MNISTDigitChannels::Pixels}) = (*sample.data())(y,x);
103  LOG(LINFO) << "Digit: " << (*sample.label());
104  }
105 
106  // Put goal and agent.
107  size_t ax,ay,gx,gy;
108 
109  // Set agent coordinates.
110  ax = (agent_x < 0) ? RAN_GEN->uniRandInt(0,width) : agent_x;
111  ay = (agent_y < 0) ? RAN_GEN->uniRandInt(0,height) : agent_y;
112  // Set initial position
113  initial_position.set(ax,ay);
115 
116  // Set goal coordinates.
117  gx = (goal_x < 0) ? RAN_GEN->uniRandInt(0,width) : goal_x;
118  gy = (goal_y < 0) ? RAN_GEN->uniRandInt(0,height) : goal_y;
119  (*environment_grid)({gx,gy,(size_t)MNISTDigitChannels::Goals}) = 10;
120 
121  // Calculate the optimal path length.
122  optimal_path_length = abs((int)ax-(int)gx) + abs((int)ay-(int)gy);
123 }
124 
125 std::string MNISTDigit::toString(mic::types::TensorXfPtr env_) {
126  std::string s;
127  // Add line.
128  s+= "+";
129  for (size_t x=0; x<env_->dim(0); x++)
130  s+="---";
131  s+= "+\n";
132 
133  for (size_t y=0; y<env_->dim(1); y++){
134  s += "|";
135  for (size_t x=0; x<env_->dim(0); x++) {
136  // Check object occupancy.
137  if ((*env_)({x,y, (size_t)MNISTDigitChannels::Agent}) != 0) {
138  // Display agent.
139  s += "<A>";
140  } else if ((*env_)({x,y, (size_t)MNISTDigitChannels::Goals}) > 0) {
141  // Display goal.
142  s += " + ";
143  } else {
144  // Display "image patch" - pixel.
145  //((*env_)({x,y, (size_t)MNISTDigitChannels::Pixels})
146  s += " ";
147  }
148  }//: for x
149  s += "|\n";
150  }//: for y
151 
152  // Add line.
153  s+= "+";
154  for (size_t x=0; x<env_->dim(0); x++)
155  s+="---";
156  s+= "+\n";
157  return s;
158 }
159 
161  return toString(environment_grid);
162 }
163 
165  if (pomdp_flag) {
166  // Get observation.
167  mic::types::TensorXfPtr obs = getObservation();
168  return toString(obs);
169  }
170  else
171  return toString(environment_grid);
172 }
173 
174 mic::types::MatrixXfPtr MNISTDigit::encodeEnvironment() {
175  // Temporarily reshape the environment_grid.
176  environment_grid->conservativeResize({1, width * height * channels});
177  // Create a matrix pointer and copy data from grid into the matrix.
178  mic::types::MatrixXfPtr encoded_grid (new mic::types::MatrixXf(*environment_grid));
179  // Back to the original shape.
180  environment_grid->resize({width, height, channels});
181 
182  // Return the matrix pointer.
183  return encoded_grid;
184 }
185 
186 mic::types::MatrixXfPtr MNISTDigit::encodeObservation() {
187  LOG(LDEBUG) << "encodeObservation()";
188  if (pomdp_flag) {
189  mic::types::Position2D p = getAgentPosition();
190  LOG(LDEBUG) << p;
191 
192  // Get observation.
193  mic::types::TensorXfPtr obs = getObservation();
194  // Temporarily reshape the observation grid.
195  obs->conservativeResize({1, roi_size * roi_size * channels});
196  // Encode the observation.
197  mic::types::MatrixXfPtr encoded_obs (new mic::types::MatrixXf(*obs));
198  // Back to the original shape.
199  obs->conservativeResize({roi_size, roi_size, channels});
200 
201  return encoded_obs;
202  }
203  else
204  return encodeEnvironment();
205 }
206 
207 
208 mic::types::TensorXfPtr MNISTDigit::getObservation() {
209  LOG(LDEBUG) << "getObservation()";
210  // Reset.
211  observation_grid->zeros();
212 
213  size_t delta = (roi_size-1)/2;
214  mic::types::Position2D p = getAgentPosition();
215 
216  // Copy data.
217  for (long oy=0, ey=(p.y-delta); oy<(long)roi_size; oy++, ey++){
218  for (long ox=0, ex=(p.x-delta); ox<(long)roi_size; ox++, ex++) {
219  // Check grid boundaries.
220  if ((ex < 0) || ((size_t)ex >= width) || (ey < 0) || ((size_t)ey >= height)){
221  // Do nothing...
222  continue;
223  }//: if
224  // Else : copy data for all channels.
225  (*observation_grid)({(size_t)ox,(size_t)oy, (size_t)MNISTDigitChannels::Goals}) = (*environment_grid)({(size_t)ex,(size_t)ey, (size_t)MNISTDigitChannels::Goals});
226  (*observation_grid)({(size_t)ox,(size_t)oy, (size_t)MNISTDigitChannels::Pixels}) = (*environment_grid)({(size_t)ex,(size_t)ey, (size_t)MNISTDigitChannels::Pixels});
227  (*observation_grid)({(size_t)ox,(size_t)oy, (size_t)MNISTDigitChannels::Agent}) = (*environment_grid)({(size_t)ex,(size_t)ey, (size_t)MNISTDigitChannels::Agent});
228  }//: for x
229  }//: for y
230 
231  //LOG(LDEBUG) << std::endl << toString(observation_grid);
232 
233  return observation_grid;
234 }
235 
236 
237 
238 mic::types::Position2D MNISTDigit::getAgentPosition() {
239  mic::types::Position2D position;
240  for (size_t y=0; y<height; y++){
241  for (size_t x=0; x<width; x++) {
242  if ((*environment_grid)({x,y, (size_t)MNISTDigitChannels::Agent}) == 1) {
243  position.x = x;
244  position.y = y;
245  return position;
246  }// if
247  }//: for x
248  }//: for y
249  // Remove warnings...
250  return position;
251 }
252 
253 bool MNISTDigit::moveAgentToPosition(mic::types::Position2D pos_) {
254  LOG(LDEBUG) << "New agent position = " << pos_;
255 
256  // Check whether the state is allowed.
257  if (!isStateAllowed(pos_))
258  return false;
259 
260  // Clear old.
261  mic::types::Position2D old = getAgentPosition();
262  (*environment_grid)({(size_t)old.x, (size_t)old.y, (size_t)MNISTDigitChannels::Agent}) = 0;
263  // Set new.
264  (*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)MNISTDigitChannels::Agent}) = 1;
265 
266  return true;
267 }
268 
269 
270 
271 float MNISTDigit::getStateReward(mic::types::Position2D pos_) {
272  // Check rewards.
273  if ((*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)MNISTDigitChannels::Goals}) != 0)
274  return (*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)MNISTDigitChannels::Goals});
275  else
276  return 0;
277 }
278 
279 
280 bool MNISTDigit::isStateAllowed(mic::types::Position2D pos_) {
281  if ((pos_.x < 0) || ((size_t)pos_.x >= width))
282  return false;
283 
284  if ((pos_.y < 0) || ((size_t)pos_.y >= height))
285  return false;
286 
287  return true;
288 }
289 
290 
291 bool MNISTDigit::isStateTerminal(mic::types::Position2D pos_) {
292  if ((pos_.x < 0) || ((size_t)pos_.x >= width))
293  return false;
294 
295  if ((pos_.y < 0) || ((size_t)pos_.y >= height))
296  return false;
297 
298  // Check reward - goal.
299  if ((*environment_grid)({(size_t)pos_.x, (size_t)pos_.y, (size_t)MNISTDigitChannels::Goals}) != 0)
300  return true;
301  else
302 
303  return false;
304 }
305 
306 } /* namespace environments */
307 } /* namespace mic */
mic::types::TensorXfPtr getObservation()
Definition: MNISTDigit.cpp:208
virtual bool moveAgentToPosition(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:253
virtual bool isStateTerminal(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:291
virtual std::string observationToString()
Definition: MNISTDigit.cpp:164
Channel storing image intensities (this is a grayscale image)
Abstract class representing an environment.
Definition: Environment.hpp:40
std::string toString(mic::types::TensorXfPtr env_)
Definition: MNISTDigit.cpp:125
virtual void initializeEnvironment()
Definition: MNISTDigit.cpp:85
size_t channels
Number of channels.
virtual mic::types::MatrixXfPtr encodeObservation()
Definition: MNISTDigit.cpp:186
virtual void initializePropertyDependentVariables()
Definition: MNISTDigit.cpp:65
mic::configuration::Property< size_t > sample_number
Definition: MNISTDigit.hpp:171
virtual void moveAgentToInitialPosition()
Definition: Environment.cpp:57
mic::configuration::Property< size_t > width
Property: width of the environment.
MNISTDigit(std::string node_name_="mnist_digit")
Definition: MNISTDigit.cpp:29
Class emulating the MNIST digit environment.
Definition: MNISTDigit.hpp:50
bool pomdp_flag
Flag indicating whether the environment is a POMDP (set when roi_size is greater than zero).
Channel storing the agent position.
mic::configuration::Property< short > agent_x
Definition: MNISTDigit.hpp:176
virtual mic::types::Position2D getAgentPosition()
Definition: MNISTDigit.cpp:238
virtual std::string environmentToString()
Definition: MNISTDigit.cpp:160
virtual mic::types::MatrixXfPtr encodeEnvironment()
Definition: MNISTDigit.cpp:174
mic::configuration::Property< size_t > roi_size
Property: size of the ROI (region of interest).
virtual float getStateReward(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:271
mic::configuration::Property< short > goal_y
Definition: MNISTDigit.hpp:191
mic::environments::MNISTDigit & operator=(const mic::environments::MNISTDigit &md_)
Definition: MNISTDigit.cpp:52
mic::types::TensorXfPtr observation_grid
virtual bool isStateAllowed(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:280
mic::importers::MNISTMatrixImporter< float > mnist_importer
Importer responsible for loading MNIST dataset.
Definition: MNISTDigit.hpp:166
mic::types::TensorXfPtr environment_grid
Tensor storing the environment.
mic::configuration::Property< short > goal_x
Definition: MNISTDigit.hpp:186
mic::configuration::Property< short > agent_y
Definition: MNISTDigit.hpp:181
mic::types::Position2D initial_position
Property: initial position of the agent.
mic::configuration::Property< size_t > height
Property: height of the environment.