MachineIntelligenceCore:NeuralNets
 All Classes Namespaces Files Functions Variables Enumerations Enumerator Friends Macros
OptimizationFunction.hpp
Go to the documentation of this file.
1 
26 #ifndef OPTIMIZATIONFUNCTIONS_HPP_
27 #define OPTIMIZATIONFUNCTIONS_HPP_
28 
29 #include <types/MatrixTypes.hpp>
30 
31 namespace mic {
32 namespace neural_nets {
33 namespace optimization {
34 
40 template <typename eT=float>
42 public:
47 
49  virtual ~OptimizationFunction () { }
50 
58  virtual void update(mic::types::MatrixPtr<eT> p_, mic::types::MatrixPtr<eT> dp_, eT learning_rate_, eT decay_ = 0.0) {
59 
60  // Calculate the update.
61  mic::types::MatrixPtr<eT> delta = calculateUpdate(p_, dp_, learning_rate_);
62 
63  //assert(std::isfinite((*delta)[i]));
64 
65  // Perform the update: x = x - delta (with optional weight decay).
66  for (size_t i=0; i< (size_t)delta->size(); i++) {
67  (*p_)[i] = (1.0f - decay_) * (*p_)[i] - (*delta)[i];
68  }//: for
69  }
70 
78  virtual void update(mic::types::MatrixPtr<eT> p_, mic::types::MatrixPtr<eT> x_, mic::types::MatrixPtr<eT> y_, eT learning_rate_ = 0.001) {
79  assert(p_->rows() == y_->rows());
80  assert(p_->cols() == x_->rows());
81  assert(x_->cols() == y_->cols());
82 
83  // Calculate the update using hebbian "fire together, wire together".
84  mic::types::MatrixPtr<eT> delta = calculateUpdate(x_, y_, learning_rate_);
85 
86  // weight += delta;
87  (*p_) += (*delta);
88  }
89 
90 
97  virtual mic::types::MatrixPtr<eT> calculateUpdate(mic::types::MatrixPtr<eT> x_, mic::types::MatrixPtr<eT> dx_, eT learning_rate_) = 0;
98 
99 
100 
101 };
102 
103 
104 } //: optimization
105 } //: neural_nets
106 } //: mic
107 
108 
109 #endif /* OPTIMIZATIONFUNCTIONS_HPP_ */
virtual ~OptimizationFunction()
Virtual destructor - empty.
Abstract class representing interface to optimization function.
virtual mic::types::MatrixPtr< eT > calculateUpdate(mic::types::MatrixPtr< eT > x_, mic::types::MatrixPtr< eT > dx_, eT learning_rate_)=0
virtual void update(mic::types::MatrixPtr< eT > p_, mic::types::MatrixPtr< eT > x_, mic::types::MatrixPtr< eT > y_, eT learning_rate_=0.001)
virtual void update(mic::types::MatrixPtr< eT > p_, mic::types::MatrixPtr< eT > dp_, eT learning_rate_, eT decay_=0.0)