MachineIntelligenceCore:NeuralNets
 All Classes Namespaces Files Functions Variables Enumerations Enumerator Friends Macros
HebbianRuleTests.cpp
Go to the documentation of this file.
1 
25 #include "HebbianRule.hpp"
26 
27 #include <gtest/gtest.h>
28 #include <cmath>
29 
30 // Redefine word "public" so every class field/method will be accessible for tests.
31 #define private public
33 
34 
39 TEST(HebbianRule, Weight3x4Update) {
40  // Input and outputs.
41  mic::types::MatrixPtr<double> x = MAKE_MATRIX_PTR(double, 3, 2);
42  (*x) << 0, 1, 1, 0, 0, 0;
43  mic::types::MatrixPtr<double> y = MAKE_MATRIX_PTR(double, 4, 2);
44  (*y) << 0, 1, 1, 0, 0, 1, 1, 1;
45 
46  // Desired result - delta.
47  mic::types::MatrixPtr<double> result_delta = MAKE_MATRIX_PTR(double, 4, 3);
48  (*result_delta) << 0.001, 0, 0, 0, 0.001, 0, 0.001, 0, 0, 0.001, 0.001, 0;
49 
50  // Rule.
51  mic::neural_nets::learning::HebbianRule<double> hebb(result_delta->rows(), result_delta->cols());
52  // Calculate update.
53  mic::types::MatrixPtr<double> delta = hebb.calculateUpdate(x, y, 0.001);
54 
55 /* std::cout << "x = \n" << (*x) << std::endl;
56  std::cout << "y = \n" << (*y) << std::endl;
57  std::cout << "delta = \n" << (*delta) << std::endl;
58  std::cout << "result_delta = \n" << (*result_delta) << std::endl;*/
59 
60  // Check dimensions.
61  ASSERT_EQ(delta->rows(), result_delta->rows());
62  ASSERT_EQ(delta->cols(), result_delta->cols());
63 
64  // Before update.
65  for (size_t i=0; i< (size_t)result_delta->size(); i++)
66  ASSERT_LE(std::fabs((*delta)[i] - (*result_delta)[i]), 0) << " at element i=" << i;
67 
68  // After update.
69  hebb.update(result_delta, x, y, 0.001);
70  for (size_t i=0; i< (size_t)result_delta->size(); i++)
71  ASSERT_LE(std::fabs((*result_delta)[i] - 2*(*delta)[i]), 0) << " at element i=" << i;
72 }
73 
74 
75 
76 int main(int argc, char **argv) {
77  testing::InitGoogleTest(&argc, argv);
78  return RUN_ALL_TESTS();
79 }
80 
81 
Updates weights according to the classical Hebbian rule, $w_{ij} \mathrel{+}= \eta \, x_j \, y_i$, where $\eta$ is the learning rate.
Definition: HebbianRule.hpp:43
virtual mic::types::MatrixPtr< eT > calculateUpdate(mic::types::MatrixPtr< eT > x_, mic::types::MatrixPtr< eT > y_, eT learning_rate_)
Definition: HebbianRule.hpp:65
TEST(HebbianRule, Weight3x4Update)
int main(int argc, char **argv)