23 #include <logger/Log.hpp>
24 #include <logger/ConsoleOutput.hpp>
25 using namespace mic::logger;
31 #include <encoders/MatrixXfMatrixXfEncoder.hpp>
32 #include <encoders/UIntMatrixXfEncoder.hpp>
34 #include <types/Batch.hpp>
38 using namespace mic::mlnn;
39 using namespace mic::types;
// Attach a console sink so that LOGGER messages are written to the terminal.
43 LOGGER->addOutput(
new ConsoleOutput());
// Build a toy dataset of `dataset_size` samples; sample i is labelled with
// class (i % classes). `classes` (presumably the number of label classes) is
// declared in lines omitted from this listing.
46 size_t dataset_size = 15;
48 Batch<MatrixXf, MatrixXf> dataset;
49 for(
size_t i=0; i< dataset_size; i++) {
// Input column vector of height dataset_size. Any statements that fill its
// coefficients (original lines 52-53) are not shown in this listing.
51 MatrixXfPtr pose (
new MatrixXf(dataset_size, 1));
54 dataset.data().push_back(pose);
// Target column vector of height `classes`; row (i % classes) is set to 1.
// NOTE(review): Eigen's sized MatrixXf constructor does not initialize the
// coefficients. Unless an omitted line (e.g. original line 58) zeroes this
// matrix, the remaining entries are indeterminate - confirm.
57 MatrixXfPtr target (
new MatrixXf(classes, 1));
59 (*target)(i%classes,0)= 1;
60 dataset.labels().push_back(target);
// Record the sample's index (presumably what sample.index() reports later).
63 dataset.indices().push_back(i);
// Encoders that turn a batch's vector of samples / targets into a single
// MatrixXfPtr (see encodeBatch below). The constructor arguments appear to
// match the per-sample dimensions (dataset_size x 1 and classes x 1).
76 mic::encoders::MatrixXfMatrixXfEncoder data_encoder(dataset_size, 1);
77 mic::encoders::MatrixXfMatrixXfEncoder
label_encoder(classes, 1);
// Training loop: draw a random batch, encode inputs and targets, then take one
// training step. `iteration` is declared (and presumably incremented) in
// lines omitted from this listing.
81 while (iteration < 10000) {
82 Batch <MatrixXf, MatrixXf> batch = dataset.getRandomBatch();
85 MatrixXfPtr encoded_batch, encoded_targets;
86 encoded_batch = data_encoder.encodeBatch(batch.data());
87 encoded_targets = label_encoder.encodeBatch(batch.labels());
// train(batch, targets, learning_rate = 0.1); the decay_ argument keeps its
// default value (0.0f).
90 float loss = nn.
train (encoded_batch, encoded_targets, 0.1);
// Report the training loss every 1000 iterations.
92 if (iteration % 1000 == 0){
93 std::cout<<
"[" << iteration <<
"]: Loss : " << loss << std::endl;
// Evaluation: walk the dataset sequentially starting from sample 0.
// NOTE(review): the training loop above runs while iteration < 10000. This
// loop only executes if an omitted line resets `iteration` (e.g. to 0) -
// confirm.
105 dataset.setNextSampleIndex(0);
106 while (iteration < dataset_size) {
107 Sample <MatrixXf, MatrixXf> sample = dataset.getNextSample();
// Print the sample index, its input (transposed) and its label (transposed);
// `iteration` is advanced here.
108 std::cout <<
"[" << iteration++ <<
"]: sample (" << sample.index() <<
"): "<< sample.data()->transpose() <<
"->" << sample.label()->transpose() << std::endl;
// Evaluate the network on this single sample (presumably without updating
// weights, unlike train()).
110 float loss = nn.
test(sample.data(), sample.label());
// Print the per-sample loss and the expected target.
113 std::cout<<
"Loss : " << loss << std::endl;
114 std::cout<<
"Targets : " << sample.label()->transpose() << std::endl;
// `predictions` is declared in lines omitted from this listing (presumably a
// dereferenced nn.getPredictions() result) - confirm.
115 std::cout<<
"Predictions : " << predictions.transpose() << std::endl << std::endl;
mic::encoders::UIntMatrixXfEncoder * label_encoder
Label-to-matrix encoder (one-hot).
mic::types::MatrixPtr< eT > getPredictions()
eT test(mic::types::MatrixPtr< eT > encoded_batch_, mic::types::MatrixPtr< eT > encoded_targets_)
eT train(mic::types::MatrixPtr< eT > encoded_batch_, mic::types::MatrixPtr< eT > encoded_targets_, eT learning_rate_, eT decay_=0.0f)
Adam - adaptive moment estimation.
void pushLayer(LayerType *layer_ptr_)