MachineIntelligenceCore:NeuralNets
 All Classes Namespaces Files Functions Variables Enumerations Enumerator Friends Macros
Momentum.hpp
Go to the documentation of this file.
1 
25 #ifndef MOMENTUM_HPP_
26 #define MOMENTUM_HPP_
27 
29 
30 namespace mic {
31 namespace neural_nets {
32 namespace optimization {
33 
38 template <typename eT=float>
39 class Momentum : public OptimizationFunction<eT> {
40 public:
41 
47  Momentum(size_t rows_, size_t cols_, eT momentum_ = 0.9) : momentum(momentum_) {
48  v = MAKE_MATRIX_PTR(eT, rows_, cols_);
49  // Reset momentum.
50  v->zeros();
51  }
52 
59  mic::types::MatrixPtr<eT> calculateUpdate(mic::types::MatrixPtr<eT> x_, mic::types::MatrixPtr<eT> dx_, eT learning_rate_ = 0.001) {
60  assert(x_->size() == dx_->size());
61  assert(x_->size() == v->size());
62 
63  // Calculate the update vector (delta).
64  for (size_t i=0; i<(size_t)x_->size(); i++)
65  (*v)[i] = momentum * (*v)[i] + learning_rate_ * (*dx_)[i];
66 
67  // Return the update.
68  return v;
69  }
70 
71 protected:
73  mic::types::MatrixPtr<eT> v;
74 
77 };
78 
79 } //: optimization
80 } //: neural_nets
81 } //: mic
82 
83 #endif /* MOMENTUM_HPP_ */
mic::types::MatrixPtr< eT > v
Update vector.
Definition: Momentum.hpp:73
Abstract class representing interface to optimization function.
mic::types::MatrixPtr< eT > calculateUpdate(mic::types::MatrixPtr< eT > x_, mic::types::MatrixPtr< eT > dx_, eT learning_rate_=0.001)
Definition: Momentum.hpp:59
Momentum(size_t rows_, size_t cols_, eT momentum_=0.9)
Definition: Momentum.hpp:47
Update in the direction of gradient descent, with momentum.
Definition: Momentum.hpp:39