MachineIntelligenceCore:NeuralNets
 All Classes Namespaces Files Functions Variables Enumerations Enumerator Friends Macros
SparseLinear.hpp
Go to the documentation of this file.
1 
25 #ifndef SPARSELINEAR_H_
26 #define SPARSELINEAR_H_
27 
29 
30 namespace mic {
31 namespace mlnn {
32 namespace fully_connected {
33 
39 template <typename eT=float>
41 public:
49  SparseLinear<eT>(size_t inputs_, size_t outputs_, std::string name_ = "SparseLinear") :
50  Linear<eT>(inputs_, outputs_, name_) {
51 
52  // Change type to SparseLinear.
54 
55  // Prepare matrices in the "temporal memory".
56  // For current sparsity vector.
57  m.add ("ro", outputSize(), 1 );
58  // For penalty.
59  m.add ("penalty", outputSize(), 1 );
60 
61  // Set desired sparsity and penalty term.
62  desired_ro = 0.1; // 10 %
63  beta = 0.5;
64  };
65 
66 
70  virtual ~SparseLinear() {};
71 
75  void backward() {
76  eT eps = 1e-10;
77  // Calculate the current "activation sparsity".
78  mic::types::MatrixPtr<eT> ro = m["ro"];
79  (*ro) = ((*s['y']).rowwise().sum()/batch_size);
80 
81  // Calculate the sparsity penalty - for every output neuron.
82  mic::types::MatrixPtr<eT> penalty = m["penalty"];
83  for (size_t i=0; i<outputSize(); i++)
84  (*penalty)[i] = beta*(-desired_ro/((*ro)[i] + eps) + (1-desired_ro)/(1-(*ro)[i] + eps));
85 
86 
87  // Calculate derivatives of W,b and x.
88  (*g['W']) = (*g['y']) * ((*s['x']).transpose());
89  (*g['b']) = (*g['y']).rowwise().mean();
90  (*g['x']) = (*p['W']).transpose() * (*g['y']);
91  }
92 
98  void update(eT alpha_, eT decay_ = 0.0f) {
99  //std::cout << "p['W'] = \n" << (*p['W']) << std::endl;
100  //std::cout << "g['W'] = \n" << (*g['W']) << std::endl;
101 
102  // Apply selected learning rule to W.
103  opt["W"]->update(p['W'], g['W'], alpha_, decay_);
104 
105  // Apply sparsity learning rule to b, incorporating the KL-divergence term.
106  mic::types::MatrixPtr<eT> penalty = m["penalty"];
107  // (*p['b']) -= alpha_ * beta * (*penalty);
108  opt["b"]->update(p['b'], g['b'], alpha_, 0.0);
109 
110  //std::cout << "p['W'] after update= \n" << (*p['W']) << std::endl;
111  }
112 
113  // Unhide the overloaded methods inherited from the template class Layer fields via "using" statement.
114  using Layer<eT>::forward;
115  using Layer<eT>::backward;
116 
117 protected:
118  // Unhide the fields inherited from the template class Layer via "using" statement.
119  using Layer<eT>::g;
120  using Layer<eT>::s;
121  using Layer<eT>::p;
122  using Layer<eT>::m;
123  using Layer<eT>::inputSize;
124  using Layer<eT>::outputSize;
125  using Layer<eT>::batch_size;
126  using Layer<eT>::opt;
127 
128 private:
129  // Friend class - required for using boost serialization.
130  template<typename tmp> friend class mic::mlnn::MultiLayerNeuralNetwork;
131 
136 
139 
141  eT beta;
142 };
143 
144 
145 } /* namespace fully_connected */
146 } /* namespace mlnn */
147 } /* namespace mic */
148 
149 #endif /* SPARSELINEAR_H_ */
size_t batch_size
Size (length) of (mini)batch.
Definition: Layer.hpp:744
size_t outputSize()
Returns size (length) of outputs.
Definition: Layer.hpp:260
void update(eT alpha_, eT decay_=0.0f)
Class representing a multi-layer neural network.
Definition: Layer.hpp:86
Class implementing a linear, fully connected layer.
Definition: Linear.hpp:42
mic::neural_nets::optimization::OptimizationArray< eT > opt
Array of optimization functions.
Definition: Layer.hpp:765
mic::types::MatrixArray< eT > s
States - contains input [x] and output [y] matrices.
Definition: Layer.hpp:753
mic::types::MatrixArray< eT > g
Gradients - contains input [x] and output [y] matrices.
Definition: Layer.hpp:756
Class implementing a linear, fully connected layer with sparsity regularization.
Definition: Linear.hpp:34
eT desired_ro
Desired sparsity of the layer.
eT beta
Controls the weight of the sparsity penalty term.
mic::types::MatrixArray< eT > m
Memory - a list of temporary parameters, to be used by the derived classes.
Definition: Layer.hpp:762
mic::types::MatrixArray< eT > p
Parameters - parameters of the layer, to be used by the derived classes.
Definition: Layer.hpp:759