MachineIntelligenceCore:NeuralNets
 All Classes Namespaces Files Functions Variables Enumerations Enumerator Friends Macros
AdaGrad.hpp
Go to the documentation of this file.
1 
25 #ifndef ADAGRAD_HPP_
26 #define ADAGRAD_HPP_
27 
29 
30 namespace mic {
31 namespace neural_nets {
32 namespace optimization {
33 
38 template <typename eT=float>
39 class AdaGrad : public OptimizationFunction<eT> {
40 public:
41 
47  AdaGrad(size_t rows_, size_t cols_, eT eps_ = 1e-8) : eps(eps_) {
48  G = MAKE_MATRIX_PTR(eT, rows_, cols_);
49  // Reset G.
50  G->zeros();
51 
52  // Allocate and reset delta.
53  delta = MAKE_MATRIX_PTR(eT, rows_, cols_);
54  delta->zeros();
55  }
56 
63  mic::types::MatrixPtr<eT> calculateUpdate(mic::types::MatrixPtr<eT> x_, mic::types::MatrixPtr<eT> dx_, eT learning_rate_) {
64  assert(x_->size() == dx_->size());
65  assert(x_->size() == G->size());
66 
67  // Update G - add square of the gradients.
68  for (size_t i=0; i<(size_t)x_->size(); i++)
69  (*G)[i] += (*dx_)[i] * (*dx_)[i];
70 
71  // delta = alpha * dW.
72  for (size_t i=0; i<(size_t)x_->size(); i++)
73  (*delta)[i] = learning_rate_ * (*dx_)[i] / (std::sqrt((*G)[i] + eps));
74 
75  // Return the update.
76  return delta;
77  }
78 
79 protected:
81  eT eps;
82 
84  mic::types::MatrixPtr<eT> G;
85 
87  mic::types::MatrixPtr<eT> delta;
88 };
89 
90 } //: optimization
91 } //: neural_nets
92 } //: mic
93 
94 #endif /* ADAGRAD_HPP_ */
AdaGrad(size_t rows_, size_t cols_, eT eps_=1e-8)
Definition: AdaGrad.hpp:47
Abstract class representing interface to optimization function.
mic::types::MatrixPtr< eT > G
Sum of all of the squares of the gradients up to time t ("diagonal matrix").
Definition: AdaGrad.hpp:84
eT eps
Smoothing term that avoids division by zero.
Definition: AdaGrad.hpp:81
mic::types::MatrixPtr< eT > calculateUpdate(mic::types::MatrixPtr< eT > x_, mic::types::MatrixPtr< eT > dx_, eT learning_rate_)
Definition: AdaGrad.hpp:63
Update using AdaGrad - adaptive gradient descent.
Definition: AdaGrad.hpp:39
mic::types::MatrixPtr< eT > delta
Calculated update.
Definition: AdaGrad.hpp:87