23 #ifndef SRC_TYPES_BATCH_HPP_
24 #define SRC_TYPES_BATCH_HPP_
43 template<
typename DataType,
typename LabelType>
98 std::vector <std::shared_ptr<DataType> > &
data() {
103 std::vector <std::shared_ptr<LabelType> > &
labels() {
117 std::shared_ptr<DataType>
data(
size_t index_) {
126 std::shared_ptr<LabelType>
labels(
size_t index_) {
171 std::uniform_int_distribution<> index_dist(0, this->
sample_data.size()-1);
173 std::vector<size_t> tmp_indices;
227 for (
size_t local_index: indices_) {
248 for (
size_t local_index: indices_) {
273 std::uniform_int_distribution<> index_dist(0, this->
sample_data.size()-1);
312 if ((index_ < 0) || (index_ >=
sample_data.size())){
314 throw std::out_of_range(
"Sample index out of range!");
319 if ((sample_index < 0) || (sample_index >=
sample_data.size())){
321 throw std::out_of_range(
"Invalid sample index!");
325 std::shared_ptr<DataType> data_ptr =
sample_data[sample_index];
326 std::shared_ptr<LabelType> label_ptr =
sample_labels[sample_index];
343 throw std::out_of_range(
"Sample index out of range!");
347 std::shared_ptr<DataType> data_ptr =
sample_data[index_];
348 std::shared_ptr<LabelType> label_ptr =
sample_labels[index_];
390 virtual void add(std::shared_ptr<DataType> data_, std::shared_ptr<LabelType> label_,
size_t index_) {
402 virtual void add(std::shared_ptr<DataType> data_, std::shared_ptr<LabelType> label_) {
416 classes.push_back(*label);
445 std::random_device
rd;
std::vector< std::shared_ptr< DataType > > & data()
Returns sample data.
std::vector< size_t > indices(size_t index_)
std::shared_ptr< LabelType > labels(size_t index_)
void setBatchSize(size_t batch_size_)
Template class storing sample batches. A batch is in fact stored as three vectors, containing the data, the labels and the sample numbers, respectively.
std::shared_ptr< DataType > data(size_t index_)
mic::types::Batch< DataType, LabelType > getBatch(std::vector< size_t > indices_)
std::vector< size_t > sample_indices
Stores sample indices (sample "positions" in original dataset).
std::vector< std::shared_ptr< LabelType > > & labels()
Returns sample labels.
mic::types::Batch< DataType, LabelType > getNextBatch()
mic::types::Sample< DataType, LabelType > getSample(size_t index_)
virtual void add(std::shared_ptr< DataType > data_, std::shared_ptr< LabelType > label_, size_t index_)
virtual void add(mic::types::Sample< DataType, LabelType > sample_)
size_t index() const
Returns the sample number (the sample "position" in original dataset).
mic::types::Batch< DataType, LabelType > getRandomBatch()
Template class storing the data-label pairs. Additionally, it stores the index of the sample (main...
mic::types::Batch< DataType, LabelType > getBatchDirect(std::vector< size_t > indices_)
std::vector< std::shared_ptr< DataType > > sample_data
Stores the data.
std::shared_ptr< LabelType > label() const
Returns the sample label.
std::vector< size_t > & indices()
Returns sample numbers (sample "positions" in original dataset).
mic::types::Sample< DataType, LabelType > getNextSample()
Batch(const mic::types::Batch< DataType, LabelType > &batch_)
mic::types::Sample< DataType, LabelType > getSampleDirect(size_t index_)
mic::types::Batch< DataType, LabelType > & operator=(const mic::types::Batch< DataType, LabelType > &batch_)
std::vector< std::shared_ptr< LabelType > > sample_labels
Stores labels.
Batch(size_t batch_size_=1)
mic::types::Sample< DataType, LabelType > getRandomSample()
virtual void add(std::shared_ptr< DataType > data_, std::shared_ptr< LabelType > label_)
std::shared_ptr< DataType > data() const
Returns the sample data.
std::mt19937_64 rng_mt19937_64
void setNextSampleIndex(size_t index_=0)