26 #ifndef OPTIMIZATIONFUNCTIONS_HPP_
27 #define OPTIMIZATIONFUNCTIONS_HPP_
29 #include <types/MatrixTypes.hpp>
32 namespace neural_nets {
33 namespace optimization {
40 template <
typename eT=
float>
58 virtual void update(mic::types::MatrixPtr<eT> p_, mic::types::MatrixPtr<eT> dp_, eT learning_rate_, eT decay_ = 0.0) {
61 mic::types::MatrixPtr<eT> delta =
calculateUpdate(p_, dp_, learning_rate_);
66 for (
size_t i=0; i< (size_t)delta->size(); i++) {
67 (*p_)[i] = (1.0f - decay_) * (*p_)[i] - (*delta)[i];
78 virtual void update(mic::types::MatrixPtr<eT> p_, mic::types::MatrixPtr<eT> x_, mic::types::MatrixPtr<eT> y_, eT learning_rate_ = 0.001) {
79 assert(p_->rows() == y_->rows());
80 assert(p_->cols() == x_->rows());
81 assert(x_->cols() == y_->cols());
84 mic::types::MatrixPtr<eT> delta =
calculateUpdate(x_, y_, learning_rate_);
97 virtual mic::types::MatrixPtr<eT>
calculateUpdate(mic::types::MatrixPtr<eT> x_, mic::types::MatrixPtr<eT> dx_, eT learning_rate_) = 0;
virtual ~OptimizationFunction()
Virtual destructor - empty.
Abstract class representing interface to optimization function.
virtual mic::types::MatrixPtr< eT > calculateUpdate(mic::types::MatrixPtr< eT > x_, mic::types::MatrixPtr< eT > dx_, eT learning_rate_)=0
virtual void update(mic::types::MatrixPtr< eT > p_, mic::types::MatrixPtr< eT > x_, mic::types::MatrixPtr< eT > y_, eT learning_rate_=0.001)
virtual void update(mic::types::MatrixPtr< eT > p_, mic::types::MatrixPtr< eT > dp_, eT learning_rate_, eT decay_=0.0)