MachineIntelligenceCore:Algorithms
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
MNISTMatrixImporter.hpp
Go to the documentation of this file.
1 
24 #ifndef SRC_importers_MNISTMATRIXIMPORTER_HPP_
25 #define SRC_importers_MNISTMATRIXIMPORTER_HPP_
26 
27 #include <importers/Importer.hpp>
28 #include <types/MNISTTypes.hpp>
29 #include <fstream>
30 
31 namespace mic {
32 namespace importers {
33 
34 template<typename T=float>
35 class MNISTMatrixImporter: public mic::importers::Importer< mic::types::Matrix<T>, unsigned int > {
36 public:
43  MNISTMatrixImporter(std::string node_name_ = "mnist_matrix_importer", std::string data_filename_ = "", std::string labels_filename_ = "") :
44  Importer< mic::types::Matrix<T>, unsigned int > (node_name_),
45  data_filename("data_filename", data_filename_),
46  labels_filename("labels_filename", labels_filename_),
47  samples_limit("samples_limit", -1)
48  {
49  // Register properties - so their values can be overridden (read from the configuration file).
50  this->registerProperty(data_filename);
51  this->registerProperty(labels_filename);
52  this->registerProperty(samples_limit);
53 
54  // Set image properties.
55  image_width = 28;
56  image_height = 28;
57  }
61  virtual ~MNISTMatrixImporter() {}
62 
67  void setDataFilename(std::string data_filename_) {
68  data_filename = data_filename_;
69  }
70 
75  void setLabelsFilename(std::string labels_filename_){
76  labels_filename = labels_filename_;
77  }
78 
83  bool importData(){
84 
85  char buffer[28*28];
86  int label_offset_bytes = 8;
87  int data_offset_bytes = 16;
88  size_t sample = 0;
89 
90  // Try to open file with labels.
91  LOG(LSTATUS) << "Opening file containing MNIST labels: " << labels_filename;
92  std::ifstream labels_file(labels_filename, std::ios::in | std::ios::binary);
93  if (!labels_file.is_open()) {
94  LOG(LFATAL) << "Oops! Couldn't find file: " << labels_filename;
95  return false;
96  }//: else
97 
98  // Read file containing images (binary format).
99  LOG(LSTATUS) << "Opening file containing MNIST images: " << data_filename;
100  std::ifstream data_file(data_filename, std::ios::in | std::ios::binary);
101  if (!data_file.is_open()) {
102  LOG(LFATAL) << "Oops! Couldn't find file: " << data_filename;
103  return false;
104  }
105 
106  // Label and image files ok - import digits.
107  LOG(LSTATUS) << "Importing MNIST digits. This might take a while...";
108 
109  // Skip label header.
110  labels_file.seekg (label_offset_bytes, std::ios::beg);
111  // Skip data header.
112  data_file.seekg (data_offset_bytes , std::ios::beg);
113 
114  // Import loop.
115  while(true) {
116  // Try to read the label.
117  labels_file.read(buffer, 1);
118  // If reached the EOF.
119  if (labels_file.eof())
120  break;
121  // Else: get the label.
122  unsigned int temp_label = (unsigned int)buffer[0];
123 
124  // Try to read the image into buffer.
125  data_file.read(buffer, image_width*image_height);
126  // If reached the EOF.
127  if (data_file.eof())
128  break;
129  // Else: get image.
130 
131  // Create new matrix of MNIST image size.
132  mic::types::MatrixPtr<T> image_ptr (new mic::types::Matrix<T>(image_height, image_width));
133 
134  // Parse and set image data.
135  for (size_t i = 0; i < (size_t)(image_width*image_height); i++) {
136  unsigned row = i / image_width;
137  unsigned col = i % image_height;
138  (*image_ptr)(row, col) = (T)((uint8_t)buffer[i])/(T)255.0f;
139  }//: for
140 
141  // Got the image and label.
142  LOG(LDEBUG) << "Loading MNIST sample: " << sample;
143 
144  sample_data.push_back(image_ptr);
145  sample_labels.push_back(std::make_shared <unsigned int> (temp_label) );
146 
147  sample++;
148  // Check limit.
149  if ((samples_limit > 0) && (sample >= (size_t)samples_limit))
150  break;
151  }//: while !eof
152 
153  LOG(LINFO) << "Imported " << sample_labels.size() << " patches";
154 
155  // Close files
156  labels_file.close();
157  data_file.close();
158 
159  // Fill the indices table(!)
160  for (size_t i=0; i < sample_data.size(); i++ )
161  sample_indices.push_back(i);
162 
163  // Count the classes.
164  //countClasses();
165  number_of_classes = 10;
166 
167  LOG(LINFO) << "Data import finished";
168  return true;
169  }
170 
175 
176 private:
181 
186 
// Name of the file with MNIST images; registered in the constructor as property "data_filename".
190  mic::configuration::Property<std::string> data_filename;
191 
// Name of the file with MNIST labels; registered in the constructor as property "labels_filename".
195  mic::configuration::Property<std::string> labels_filename;
196 
// Limit on the number of imported samples; importData() applies it only when the value is positive (the constructor sets -1).
200  mic::configuration::Property<int> samples_limit;
201 
// Makes the sample_data member of the Importer base class usable by its unqualified name.
202  using Importer< mic::types::Matrix<T>, unsigned int >::sample_data;
206 };
207 
208 
209 } /* namespace importers */
210 } /* namespace mic */
211 
212 
213 #endif /* SRC_importers_MNISTMATRIXIMPORTER_HPP_ */
void setLabelsFilename(std::string labels_filename_)
mic::configuration::Property< std::string > data_filename
std::vector< size_t > sample_indices
Stores sample indices (sample "positions" in original dataset).
Definition: Batch.hpp:460
MNISTMatrixImporter(std::string node_name_="mnist_matrix_importer", std::string data_filename_="", std::string labels_filename_="")
Parent class for all data importers.
Definition: Importer.hpp:51
std::vector< std::shared_ptr< mic::types::Matrix< T > > > sample_data
Stores the data.
Definition: Batch.hpp:454
Contains declaration (and definition) of base template class of all data importers.
mic::configuration::Property< std::string > labels_filename
typename std::shared_ptr< mic::types::Matrix< T > > MatrixPtr
Typedef for a shared pointer to template-typed dynamic matrices.
Definition: Matrix.hpp:479
void setDataFilename(std::string data_filename_)
std::vector< std::shared_ptr< unsigned int > > sample_labels
Stores labels.
Definition: Batch.hpp:457
mic::configuration::Property< int > samples_limit
Template-typed Matrix of dynamic size. Uses OpenBLAS if found by CMake; provides overloaded specializations of operator* for the types float and double.
Definition: Matrix.hpp:64