MachineIntelligenceCore:Algorithms
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
Batch.hpp
Go to the documentation of this file.
1 
23 #ifndef SRC_TYPES_BATCH_HPP_
24 #define SRC_TYPES_BATCH_HPP_
25 
26 #include <types/Sample.hpp>
27 
28 #include <random>
29 
30 #include <list>
31 
32 
33 namespace mic {
34 namespace types {
35 
43 template<typename DataType, typename LabelType>
44 class Batch {
45 public:
49  Batch(size_t batch_size_ = 1) :
51  batch_size(batch_size_),
53  {
54 
55  }
56 
63  {
64  // Copy parameters.
66  batch_size = batch_.batch_size;
67 
68  // Copy "samples".
69  this->sample_data = batch_.sample_data;
70  this->sample_labels = batch_.sample_labels;
71  this->sample_indices = batch_.sample_indices;
72  // And number of classes.
73  this->number_of_classes = batch_.number_of_classes;
74  }
75 
82  // Copy "samples".
83  this->sample_data = batch_.sample_data;
84  this->sample_labels = batch_.sample_labels;
85  this->sample_indices = batch_.sample_indices;
86  // And number of classes.
87  this->number_of_classes = batch_.number_of_classes;
88  // Return the object.
89  return *this;
90  }
91 
95  virtual ~Batch() { };
96 
98  std::vector <std::shared_ptr<DataType> > & data() {
99  return sample_data;
100  }
101 
103  std::vector <std::shared_ptr<LabelType> > & labels() {
104  return sample_labels;
105  }
106 
108  std::vector <size_t> & indices() {
109  return sample_indices;
110  }
111 
117  std::shared_ptr<DataType> data(size_t index_) {
118  return sample_data[index_];
119  }
120 
126  std::shared_ptr<LabelType> labels(size_t index_) {
127  return sample_labels[index_];
128  }
129 
135  std::vector <size_t> indices(size_t index_) {
136  return sample_indices[index_];
137  }
138 
143  size_t size() {
144  return sample_data.size();
145  }
146 
151  void setBatchSize(size_t batch_size_) {
152  batch_size = batch_size_;
153  }
154 
155 
159  size_t getBatchSize() {
160  return batch_size;
161  }
162 
163 
169 
170  // Initialize uniform index distribution - integers.
171  std::uniform_int_distribution<> index_dist(0, this->sample_data.size()-1);
172 
173  std::vector<size_t> tmp_indices;
174  for (size_t i=0; i<batch_size; i++) {
175  // Pick an index.
176  tmp_indices.push_back((size_t)index_dist(rng_mt19937_64));
177 
178  }//: batch_size
179 
180  // Return data.
181  return this->getBatchDirect(tmp_indices);
182  }
183 
184 
192 
193  // Check index.
194  if((next_sample_index+batch_size) > this->sample_data.size()){
195  // Reset index.
196  next_sample_index = 0;
197  }
198  // Generate list of indices.
199  std::vector<size_t> indices;
200  for (size_t i=0; i<batch_size; i++) {
201  // Pick an index.
202  indices.push_back((size_t)(next_sample_index+i));
203 
204  }//: batch_size
205 
206  // Increment index.
207  next_sample_index = (size_t)next_sample_index + (size_t)batch_size;
208  // Return data.
209  return this->getBatchDirect(indices);
210  }
211 
212 
219  mic::types::Batch<DataType, LabelType> getBatch(std::vector<size_t> indices_) {
220 
221  // New empty batch.
223  // Set number of classes.
225 
226  // For all indices.
227  for (size_t local_index: indices_) {
228  // Try to add sample to batch.
229  batch.add(getSample(local_index));
230  }
231  // Return batch.
232  return batch;
233  }
234 
235 
243 
244  // New empty batch.
246 
247  // For all indices.
248  for (size_t local_index: indices_) {
249  // Try to add sample to batch.
250  batch.add(getSampleDirect(local_index));
251  }
252  // Return batch.
253  return batch;
254  }
255 
256 
261  bool isLastBatch() {
262  return ((next_sample_index+batch_size) >= this->sample_data.size());
263  }
264 
265 
271 
272  // Initialize uniform index distribution - integers.
273  std::uniform_int_distribution<> index_dist(0, this->sample_data.size()-1);
274 
275  // Pick an index.
276  unsigned int tmp_index= index_dist(rng_mt19937_64);
277 
278  //LOG(LDEBUG) << "data size = " << this->sample_data.size() << " labels size = " << this->sample_labels.size() << " index = " << tmp_index;
279 
280  // Return data.
281  return this->getSampleDirect(tmp_index);
282  }
283 
291 
292  // Check index.
293  if(next_sample_index >= this->sample_data.size()){
294  // Reset index.
295  next_sample_index = 0;
296  }
297  // Return given sample and increment index afterwards.
299  next_sample_index += 1;
300  return sample;
301  }
302 
303 
310 
311  // Check index.
312  if ((index_ < 0) || (index_ >= sample_data.size())){
313  // Reset index.
314  throw std::out_of_range("Sample index out of range!");
315  }//: if
316 
317  // Check sample index.
318  size_t sample_index = sample_indices[index_];
319  if ((sample_index < 0) || (sample_index >= sample_data.size())){
320  // Reset index.
321  throw std::out_of_range("Invalid sample index!");
322  }//: if
323 
324  // Get data
325  std::shared_ptr<DataType> data_ptr = sample_data[sample_index];
326  std::shared_ptr<LabelType> label_ptr = sample_labels[sample_index];
327 
328  // Return data.
329  return Sample<DataType, LabelType> (data_ptr, label_ptr, sample_index);
330  }
331 
339 
340  // Check index.
341  if (index_ >= sample_data.size()){
342  // Reset index.
343  throw std::out_of_range("Sample index out of range!");
344  }//: if
345 
346  // Get data
347  std::shared_ptr<DataType> data_ptr = sample_data[index_];
348  std::shared_ptr<LabelType> label_ptr = sample_labels[index_];
349  size_t sample_index = sample_indices[index_];
350 
351  // Return data.
352  return Sample<DataType, LabelType> (data_ptr, label_ptr, sample_index);
353  }
354 
355 
360  void setNextSampleIndex(size_t index_ = 0) {
361  next_sample_index = index_;
362  }
363 
368  bool isLastSample() {
369  return (next_sample_index >= this->sample_data.size());
370  }
371 
372 
378  // Add sample to vectors.
379  sample_data.push_back(sample_.data());
380  sample_labels.push_back(sample_.label());
381  sample_indices.push_back(sample_.index());
382  }
383 
390  virtual void add(std::shared_ptr<DataType> data_, std::shared_ptr<LabelType> label_, size_t index_) {
391  // Add sample to vectors.
392  sample_data.push_back(data_);
393  sample_labels.push_back(label_);
394  sample_indices.push_back(index_);
395  }
396 
402  virtual void add(std::shared_ptr<DataType> data_, std::shared_ptr<LabelType> label_) {
403  // Add sample to vectors.
404  sample_data.push_back(data_);
405  sample_labels.push_back(label_);
406  sample_indices.push_back(sample_indices.size());
407  }
408 
413  void countClasses() {
414  std::list<LabelType> classes;
415  for(auto label: this->sample_labels) {
416  classes.push_back(*label);
417  }//: for
418  classes.sort();
419  classes.unique();
420  this->number_of_classes = classes.size();
421  }
422 
427  size_t classes() {
428  return number_of_classes;
429  }
430 
431 protected:
436 
440  size_t batch_size;
441 
445  std::random_device rd;
446 
450  std::mt19937_64 rng_mt19937_64;
451 
452 
454  std::vector <std::shared_ptr<DataType> > sample_data;
455 
457  std::vector <std::shared_ptr<LabelType> > sample_labels;
458 
460  std::vector <size_t> sample_indices;
461 
466 
467 };
468 
469 
470 } /* namespace types */
471 } /* namespace mic */
472 
473 
474 
475 
476 #endif /* SRC_TYPES_BATCH_HPP_ */
size_t number_of_classes
Definition: Batch.hpp:465
std::vector< std::shared_ptr< DataType > > & data()
Returns sample data.
Definition: Batch.hpp:98
std::vector< size_t > indices(size_t index_)
Definition: Batch.hpp:135
std::shared_ptr< LabelType > labels(size_t index_)
Definition: Batch.hpp:126
size_t size()
Definition: Batch.hpp:143
void setBatchSize(size_t batch_size_)
Definition: Batch.hpp:151
Template class storing the sample batches. A batch is stored in fact as three vectors, containing data, labels and sample numbers respectively.
Definition: Batch.hpp:44
std::shared_ptr< DataType > data(size_t index_)
Definition: Batch.hpp:117
size_t getBatchSize()
Definition: Batch.hpp:159
bool isLastSample()
Definition: Batch.hpp:368
mic::types::Batch< DataType, LabelType > getBatch(std::vector< size_t > indices_)
Definition: Batch.hpp:219
std::vector< size_t > sample_indices
Stores sample indices (sample "positions" in original dataset).
Definition: Batch.hpp:460
std::vector< std::shared_ptr< LabelType > > & labels()
Returns sample labels.
Definition: Batch.hpp:103
mic::types::Batch< DataType, LabelType > getNextBatch()
Definition: Batch.hpp:191
mic::types::Sample< DataType, LabelType > getSample(size_t index_)
Definition: Batch.hpp:309
virtual void add(std::shared_ptr< DataType > data_, std::shared_ptr< LabelType > label_, size_t index_)
Definition: Batch.hpp:390
virtual void add(mic::types::Sample< DataType, LabelType > sample_)
Definition: Batch.hpp:377
size_t index() const
Returns the sample number (the sample "position" in original dataset).
Definition: Sample.hpp:90
mic::types::Batch< DataType, LabelType > getRandomBatch()
Definition: Batch.hpp:168
Template class storing the data-label pairs. Additionally it stores the index of the sample (main...
Definition: Sample.hpp:38
mic::types::Batch< DataType, LabelType > getBatchDirect(std::vector< size_t > indices_)
Definition: Batch.hpp:242
std::vector< std::shared_ptr< DataType > > sample_data
Stores the data.
Definition: Batch.hpp:454
size_t batch_size
Definition: Batch.hpp:440
bool isLastBatch()
Definition: Batch.hpp:261
std::shared_ptr< LabelType > label() const
Returns the sample label.
Definition: Sample.hpp:85
size_t classes()
Definition: Batch.hpp:427
std::vector< size_t > & indices()
Returns sample numbers (sample "positions" in original dataset).
Definition: Batch.hpp:108
mic::types::Sample< DataType, LabelType > getNextSample()
Definition: Batch.hpp:290
std::random_device rd
Definition: Batch.hpp:445
Batch(const mic::types::Batch< DataType, LabelType > &batch_)
Definition: Batch.hpp:61
mic::types::Sample< DataType, LabelType > getSampleDirect(size_t index_)
Definition: Batch.hpp:338
size_t next_sample_index
Definition: Batch.hpp:435
virtual ~Batch()
Definition: Batch.hpp:95
mic::types::Batch< DataType, LabelType > & operator=(const mic::types::Batch< DataType, LabelType > &batch_)
Definition: Batch.hpp:81
std::vector< std::shared_ptr< LabelType > > sample_labels
Stores labels.
Definition: Batch.hpp:457
Batch(size_t batch_size_=1)
Definition: Batch.hpp:49
mic::types::Sample< DataType, LabelType > getRandomSample()
Definition: Batch.hpp:270
virtual void add(std::shared_ptr< DataType > data_, std::shared_ptr< LabelType > label_)
Definition: Batch.hpp:402
std::shared_ptr< DataType > data() const
Returns the sample data.
Definition: Sample.hpp:80
std::mt19937_64 rng_mt19937_64
Definition: Batch.hpp:450
void setNextSampleIndex(size_t index_=0)
Definition: Batch.hpp:360
void countClasses()
Definition: Batch.hpp:413