23 #ifndef SRC_MLNN_SOFTMAX_HPP_
24 #define SRC_MLNN_SOFTMAX_HPP_
30 namespace cost_function {
37 template <
typename eT=
float>
46 Softmax(
size_t size_, std::string name_ =
"Softmax") :
60 Softmax(
size_t height_,
size_t width_,
size_t depth_,
61 std::string name_ =
"Softmax") :
63 height_, width_, depth_,
87 m[
"e"]->resize(
m[
"e"]->rows(), batch_size_);
88 m[
"sum"]->resize(
m[
"sum"]->rows(), batch_size_);
89 m[
"max"]->resize(
m[
"max"]->rows(), batch_size_);
95 mic::types::MatrixPtr<eT> x =
s[
"x"];
96 mic::types::MatrixPtr<eT> y =
s[
"y"];
97 mic::types::MatrixPtr<eT> e =
m[
"e"];
98 mic::types::MatrixPtr<eT> max =
m[
"max"];
99 mic::types::MatrixPtr<eT> sum =
m[
"sum"];
104 (*max) = x->colwise().maxCoeff();
107 for (
size_t i = 0; i < (size_t)y->rows(); i++)
108 for (
size_t j = 0; j < (size_t)y->cols(); j++)
109 (*e)(i, j) = std::exp( (*x)(i, j) - (*max)(j) );
112 (*sum) = e->colwise().sum();
115 for (
size_t i = 0; i < (size_t)y->rows(); i++) {
116 for (
size_t j = 0; j < (size_t)y->cols(); j++) {
117 (*y)(i, j) = (*e)(i, j) / (*sum)(j);
125 mic::types::MatrixPtr<eT> y =
s[
"y"];
126 mic::types::MatrixPtr<eT> dx =
g[
"x"];
127 mic::types::MatrixPtr<eT> dy =
g[
"y"];
130 for (
size_t i = 0; i < (size_t)y->size(); i++)
132 (*dx)[i] = (*dy)[i] * (*y)[i] * (1 - (*y)[i]);
143 virtual void update(eT alpha_, eT decay_ = 0.0f) { };
Softmax(size_t height_, size_t width_, size_t depth_, std::string name_="Softmax")
virtual void resizeBatch(size_t batch_size_)
virtual void resizeBatch(size_t batch_size_)
Softmax activation function.
Class representing a multi-layer neural network.
LayerTypes
Enumeration of possible layer types.
mic::types::MatrixArray< eT > s
States - contains input [x] and output [y] matrices.
mic::types::MatrixArray< eT > g
Gradients - contains the gradients with respect to the input [x] and output [y] matrices.
void forward(bool test_=false)
Softmax(size_t size_, std::string name_="Softmax")
virtual void update(eT alpha_, eT decay_=0.0f)
Contains a template class representing a layer.
mic::types::MatrixArray< eT > m
Memory - a list of temporary (scratch) matrices, to be used by the derived classes.