MachineIntelligenceCore:NeuralNets
Adam.hpp
#ifndef ADAM_HPP_
#define ADAM_HPP_

// Base class: abstract interface to optimization functions.
#include <optimization/OptimizationFunction.hpp>

namespace mic {
namespace neural_nets {
namespace optimization {

/*!
 * \brief Adam - adaptive moment estimation.
 */
template <typename eT=float>
class Adam : public OptimizationFunction<eT> {
public:

	/*!
	 * Constructor. Allocates and zeroes the moment matrices and the update,
	 * and initializes the "powered" decay factors used for bias correction.
	 * @param rows_ Number of rows of the optimized matrix (and its gradient).
	 * @param cols_ Number of columns of the optimized matrix (and its gradient).
	 * @param beta1_ Decay rate 1 (default = 0.9).
	 * @param beta2_ Decay rate 2 (default = 0.999).
	 * @param eps_ Smoothing term that avoids division by zero (default = 1e-8).
	 */
	Adam(size_t rows_, size_t cols_, eT beta1_ = 0.9, eT beta2_ = 0.999, eT eps_ = 1e-8)
		: beta1(beta1_), beta2(beta2_), eps(eps_)
	{
		// Allocate and reset the estimate of the first moment (average of gradients).
		m = MAKE_MATRIX_PTR(eT, rows_, cols_);
		m->zeros();

		// Allocate and reset the estimate of the second moment (average of squared gradients).
		v = MAKE_MATRIX_PTR(eT, rows_, cols_);
		v->zeros();

		// Allocate and reset delta.
		delta = MAKE_MATRIX_PTR(eT, rows_, cols_);
		delta->zeros();

		// Initialize the "powered" decay factors (t = 1).
		beta1_powt = beta1;
		beta2_powt = beta2;
	}

	/*!
	 * Calculates the update according to the Adam rule.
	 * @param x_ Pointer to the optimized matrix (parameters).
	 * @param dx_ Pointer to the gradient of the optimized matrix.
	 * @param learning_rate_ Learning rate (default = 0.001).
	 * @return Pointer to the calculated update.
	 */
	mic::types::MatrixPtr<eT> calculateUpdate(mic::types::MatrixPtr<eT> x_, mic::types::MatrixPtr<eT> dx_, eT learning_rate_ = 0.001) {
		assert(x_->size() == dx_->size());
		assert(x_->size() == m->size());

		// Update the decaying average of past gradients.
		for (size_t i = 0; i < (size_t)x_->size(); i++)
			(*m)[i] = beta1 * (*m)[i] + (1 - beta1) * (*dx_)[i];

		// Update the decaying average of past squared gradients.
		for (size_t i = 0; i < (size_t)x_->size(); i++)
			(*v)[i] = beta2 * (*v)[i] + (1 - beta2) * (*dx_)[i] * (*dx_)[i];

		// Calculate the bias-corrected update.
		for (size_t i = 0; i < (size_t)x_->size(); i++)
			(*delta)[i] = learning_rate_ / (sqrt((*v)[i] / (1 - beta2_powt)) + eps) * (*m)[i] / (1 - beta1_powt);

		// Update the "powered" decay factors.
		beta1_powt *= beta1;
		beta2_powt *= beta2;

		// Return the update.
		return delta;
	}

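	// Note (added for clarity, not part of the original header): the three loops
	// above implement the standard Adam update with bias correction,
	//   m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
	//   v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
	//   delta_t = lr * (m_t / (1 - beta1^t)) / (sqrt(v_t / (1 - beta2^t)) + eps)
	// where g_t is the gradient dx_, and beta1_powt, beta2_powt store beta1^t, beta2^t.
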
protected:
	/// Exponentially decaying average of past gradients.
	mic::types::MatrixPtr<eT> m;

	/// Exponentially decaying average of past squared gradients.
	mic::types::MatrixPtr<eT> v;

	/// Calculated update.
	mic::types::MatrixPtr<eT> delta;

	/// Decay rate 1 (momentum for past gradients).
	eT beta1;

	/// Decay rate 2 (momentum for past squared gradients).
	eT beta2;

	/// Smoothing term that avoids division by zero.
	eT eps;

	/// Decay rate 1 to the power of t.
	eT beta1_powt;

	/// Decay rate 2 to the power of t.
	eT beta2_powt;

};

} //: optimization
} //: neural_nets
} //: mic

#endif /* ADAM_HPP_ */
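
A minimal usage sketch (not part of Adam.hpp), showing one optimization step with the class above. The include path and the element-wise application of the returned update are assumptions made for illustration; MAKE_MATRIX_PTR and the MatrixPtr element access come from the MIC types library already used by the header.

#include <optimization/Adam.hpp>

using namespace mic::neural_nets::optimization;

// Parameter matrix and its gradient (2x3 in this sketch), assumed to be filled elsewhere.
mic::types::MatrixPtr<float> p  = MAKE_MATRIX_PTR(float, 2, 3);
mic::types::MatrixPtr<float> dp = MAKE_MATRIX_PTR(float, 2, 3);

// Optimizer matching the parameter shape, with default beta1/beta2/eps.
Adam<float> adam(2, 3);

// One step: compute the Adam update and apply it element-wise (illustrative).
mic::types::MatrixPtr<float> delta = adam.calculateUpdate(p, dp, 0.001f);
for (size_t i = 0; i < (size_t)p->size(); i++)
	(*p)[i] -= (*delta)[i];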