MachineIntelligenceCore:NeuralNets
 All Classes Namespaces Files Functions Variables Enumerations Enumerator Friends Macros
CrossEntropyLoss.hpp
Go to the documentation of this file.
1 
27 #ifndef CROSSENTROPYLOSS_HPP_
28 #define CROSSENTROPYLOSS_HPP_
29 
30 #include <cmath>
31 #include <loss/Loss.hpp>
32 
33 namespace mic {
34 namespace neural_nets {
35 namespace loss {
36 
42 template <typename dtype=float>
43 class CrossEntropyLoss : public Loss<dtype> {
44 public:
48  dtype calculateLoss (mic::types::MatrixPtr<dtype> target_y_, mic::types::MatrixPtr<dtype> predicted_y_) {
49  // Sizes must match.
50  assert(predicted_y_->size() == target_y_->size());
51 
52  // Calculate loss (negative log probability).
53  dtype loss =0;
54  dtype eps = 1e-15;
55  for (size_t i=0; i <(size_t)predicted_y_->size(); i++) {
56  // -t * log (y + eps!)
57  loss -= (*target_y_)[i] * std::log2((*predicted_y_)[i] + eps);
58  }
59  // Return cross-entropy error (CE).
60  // The average cross entropy error (ACE) is loss divided by the batch size.
61  //assert(std::isfinite(loss));
62  return loss;
63  }
64 
68  mic::types::MatrixPtr<dtype> calculateGradient (mic::types::MatrixPtr<dtype> target_y_, mic::types::MatrixPtr<dtype> predicted_y_) {
69  // Sizes must match.
70  assert(predicted_y_->size() == target_y_->size());
71 
72  // Calculate gradient.
73  mic::types::MatrixPtr<dtype> dy = MAKE_MATRIX_PTR(dtype, predicted_y_->rows(), predicted_y_->cols());
74  for (size_t i=0; i <(size_t)predicted_y_->size(); i++) {
75  // y - t
76  (*dy)[i] = (*predicted_y_)[i] - (*target_y_)[i];
77  }
78  return dy;
79  }
80 
81 };
82 
83 } //: loss
84 } //: neural_nets
85 } //: mic
86 
87 #endif /* CROSSENTROPYLOSS_HPP_ */
Class representing a cross-entropy loss function (classification).
mic::types::MatrixPtr< dtype > calculateGradient(mic::types::MatrixPtr< dtype > target_y_, mic::types::MatrixPtr< dtype > predicted_y_)
Gradient calculation for cross-entropy.
Abstract class representing a loss function. Defines interfaces.
Definition: Loss.hpp:41
dtype calculateLoss(mic::types::MatrixPtr< dtype > target_y_, mic::types::MatrixPtr< dtype > predicted_y_)
Calculates cross entropy (using log) and returns the cross-entropy error (CE).