MachineIntelligenceCore:ReinforcementLearning
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator
MNISTDigit.hpp
Go to the documentation of this file.
1 
23 #ifndef SRC_TYPES_MNISTDIGIT_HPP_
24 #define SRC_TYPES_MNISTDIGIT_HPP_
25 
26 #include <types/Environment.hpp>
27 #include <importers/MNISTMatrixImporter.hpp>
28 
29 
30 namespace mic {
31 namespace environments {
32 
/*!
 * \brief Channels of the MNIST digit environment.
 *
 * Every enumerator except Count names one channel; Count equals the number of channels.
 * Only Pixels is spelled out explicitly; the remaining values follow implicitly (1, 2, 3).
 */
enum class MNISTDigitChannels : std::size_t
{
	/// Channel storing image intensities (the image is grayscale).
	Pixels = 0,
	/// Channel for the goal(s) (presumably goal position markers).
	Goals,
	/// Channel storing the agent position.
	Agent,
	/// Number of channels (not a channel itself).
	Count
};
44 
45 
51 public:
/// Constructs the environment; node_name_ defaults to "mnist_digit" (defined in MNISTDigit.cpp:29).
56  MNISTDigit(std::string node_name_ = "mnist_digit");
57 
63 
/// Virtual destructor.
67  virtual ~MNISTDigit();
68 
// NOTE(review): header lines 69-77 are not shown in this listing. The tooltips reference
// operator=(const mic::environments::MNISTDigit&) (MNISTDigit.cpp:52) and
// initializePropertyDependentVariables() (MNISTDigit.cpp:65), which are presumably declared
// in this gap — confirm against the real header.
73 
78 
/// Initializes the environment (defined in MNISTDigit.cpp:85); presumably sets up the
/// MNIST sample, agent and goal — confirm in the implementation.
82  virtual void initializeEnvironment();
83 
/// Returns the current observation as a tensor. Note: unlike its neighbours, not declared virtual.
88  mic::types::TensorXfPtr getObservation();
89 
/// Returns a textual rendering of the whole environment.
94  virtual std::string environmentToString();
95 
/// Returns a textual rendering of the current observation.
100  virtual std::string observationToString();
101 
/// Encodes the whole environment as a matrix.
106  virtual mic::types::MatrixXfPtr encodeEnvironment();
107 
/// Encodes the current observation as a matrix.
112  virtual mic::types::MatrixXfPtr encodeObservation();
113 
/// Returns the agent's current position.
118  virtual mic::types::Position2D getAgentPosition();
119 
/// Attempts to move the agent to pos_ and returns whether it succeeded
/// (presumably false when pos_ is not an allowed state — confirm in MNISTDigit.cpp:253).
126  virtual bool moveAgentToPosition(mic::types::Position2D pos_);
127 
/// Returns the reward associated with the state at pos_.
133  virtual float getStateReward(mic::types::Position2D pos_);
134 
135  // Makes all overloads of isStateAllowed() (including the base-class isStateAllowed(long, long)) visible here.
// NOTE(review): header line 136 — presumably the `using` declaration this comment describes — is
// missing from this listing. Without it, the declaration below would hide the base-class overload.
137 
/// Returns true if the agent may occupy position pos_.
143  virtual bool isStateAllowed(mic::types::Position2D pos_);
144 
145  // Makes all overloads of isStateTerminal() (including the base-class isStateTerminal(long, long)) visible here.
// NOTE(review): header line 146 — presumably the matching `using` declaration — is missing from this listing.
147 
/// Returns true if pos_ is a terminal state (presumably the goal — confirm in MNISTDigit.cpp:291).
153  virtual bool isStateTerminal(mic::types::Position2D pos_);
154 
159  unsigned int optimalPathLength(){
160  return optimal_path_length;
161  }
162 
163 protected:
164 
/// Importer responsible for loading the MNIST dataset.
166  mic::importers::MNISTMatrixImporter<float> mnist_importer;
167 
/// Configurable property; presumably the index of the MNIST sample used as the environment — confirm in MNISTDigit.cpp.
171  mic::configuration::Property<size_t> sample_number;
172 
/// Configurable property; presumably the agent's initial x coordinate — confirm.
176  mic::configuration::Property<short> agent_x;
177 
/// Configurable property; presumably the agent's initial y coordinate — confirm.
181  mic::configuration::Property<short> agent_y;
182 
/// Configurable property; presumably the goal's x coordinate — confirm.
186  mic::configuration::Property<short> goal_x;
187 
/// Configurable property; presumably the goal's y coordinate — confirm.
191  mic::configuration::Property<short> goal_y;
192 
/// Optimal path length, returned by optimalPathLength(); where it is computed is not visible here.
196  unsigned int optimal_path_length;
197 
198 
/// Renders the given environment tensor as a string (defined in MNISTDigit.cpp:125);
/// presumably shared by environmentToString()/observationToString() — confirm.
204  std::string toString(mic::types::TensorXfPtr env_);
205 
206 };
207 
208 } /* namespace environments */
209 } /* namespace mic */
210 
211 #endif /* SRC_TYPES_MNISTDIGIT_HPP_ */
mic::types::TensorXfPtr getObservation()
Definition: MNISTDigit.cpp:208
virtual bool moveAgentToPosition(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:253
virtual bool isStateTerminal(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:291
virtual std::string observationToString()
Definition: MNISTDigit.cpp:164
Channel storing image intensities (this is a grayscale image)
Abstract class representing an environment.
Definition: Environment.hpp:40
std::string toString(mic::types::TensorXfPtr env_)
Definition: MNISTDigit.cpp:125
virtual void initializeEnvironment()
Definition: MNISTDigit.cpp:85
virtual bool isStateTerminal(long x_, long y_)
Definition: Environment.cpp:66
virtual mic::types::MatrixXfPtr encodeObservation()
Definition: MNISTDigit.cpp:186
virtual void initializePropertyDependentVariables()
Definition: MNISTDigit.cpp:65
MNISTDigitChannels
MNIST Digit environment channels.
Definition: MNISTDigit.hpp:37
mic::configuration::Property< size_t > sample_number
Definition: MNISTDigit.hpp:171
MNISTDigit(std::string node_name_="mnist_digit")
Definition: MNISTDigit.cpp:29
Class emulating the MNIST digit environment.
Definition: MNISTDigit.hpp:50
Channel storing the agent position.
mic::configuration::Property< short > agent_x
Definition: MNISTDigit.hpp:176
virtual mic::types::Position2D getAgentPosition()
Definition: MNISTDigit.cpp:238
virtual std::string environmentToString()
Definition: MNISTDigit.cpp:160
unsigned int optimalPathLength()
Definition: MNISTDigit.hpp:159
virtual mic::types::MatrixXfPtr encodeEnvironment()
Definition: MNISTDigit.cpp:174
virtual float getStateReward(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:271
mic::configuration::Property< short > goal_y
Definition: MNISTDigit.hpp:191
mic::environments::MNISTDigit & operator=(const mic::environments::MNISTDigit &md_)
Definition: MNISTDigit.cpp:52
virtual bool isStateAllowed(mic::types::Position2D pos_)
Definition: MNISTDigit.cpp:280
mic::importers::MNISTMatrixImporter< float > mnist_importer
Importer responsible for loading MNIST dataset.
Definition: MNISTDigit.hpp:166
mic::configuration::Property< short > goal_x
Definition: MNISTDigit.hpp:186
mic::configuration::Property< short > agent_y
Definition: MNISTDigit.hpp:181
virtual bool isStateAllowed(long x_, long y_)
Definition: Environment.cpp:62