27 #include <gtest/gtest.h>
31 #define private public
39 TEST(HebbianRule, Weight3x4Update) {
41 mic::types::MatrixPtr<double> x = MAKE_MATRIX_PTR(
double, 3, 2);
42 (*x) << 0, 1, 1, 0, 0, 0;
43 mic::types::MatrixPtr<double> y = MAKE_MATRIX_PTR(
double, 4, 2);
44 (*y) << 0, 1, 1, 0, 0, 1, 1, 1;
47 mic::types::MatrixPtr<double> result_delta = MAKE_MATRIX_PTR(
double, 4, 3);
48 (*result_delta) << 0.001, 0, 0, 0, 0.001, 0, 0.001, 0, 0, 0.001, 0.001, 0;
53 mic::types::MatrixPtr<double> delta = hebb.
calculateUpdate(x, y, 0.001);
61 ASSERT_EQ(delta->rows(), result_delta->rows());
62 ASSERT_EQ(delta->cols(), result_delta->cols());
65 for (
size_t i=0; i< (size_t)result_delta->size(); i++)
66 ASSERT_LE(std::fabs((*delta)[i] - (*result_delta)[i]), 0) <<
" at element i=" << i;
69 hebb.update(result_delta, x, y, 0.001);
70 for (
size_t i=0; i< (size_t)result_delta->size(); i++)
71 ASSERT_LE(std::fabs((*result_delta)[i] - 2*(*delta)[i]), 0) <<
" at element i=" << i;
76 int main(
int argc,
char **argv) {
77 testing::InitGoogleTest(&argc, argv);
78 return RUN_ALL_TESTS();
Updates the weights according to the classical Hebbian rule (wij += ni * x * y, where ni is the learning rate).
virtual mic::types::MatrixPtr< eT > calculateUpdate(mic::types::MatrixPtr< eT > x_, mic::types::MatrixPtr< eT > y_, eT learning_rate_)
TEST(HebbianRule, Weight3x4Update)
int main(int argc, char **argv)