28 #include <types/MatrixTypes.hpp>
31 namespace neural_nets {
40 template <
typename dtype=
float>
/**
 * @brief Computes the loss for a pair of target and predicted outputs.
 *
 * Pure virtual: every concrete loss function must provide an implementation.
 * calculateMeanLoss() divides the value returned here by the number of
 * columns of predicted_y_. The result is therefore presumably the loss
 * accumulated over the whole batch (confirm in implementations).
 *
 * @param target_y_    Target (desired) outputs.
 * @param predicted_y_ Predicted outputs.
 * @return Loss value, of type dtype.
 */
46 virtual dtype
calculateLoss (mic::types::MatrixPtr<dtype> target_y_, mic::types::MatrixPtr<dtype> predicted_y_) = 0;
/**
 * @brief Calculates the mean loss, i.e. the loss divided by the batch size.
 *
 * Delegates to calculateLoss() and divides its result by the number of
 * columns of predicted_y_. The column count is treated here as the batch
 * size. Declared virtual, so derived classes may override it.
 * NOTE(review): there is no guard against a predicted_y_ with zero columns.
 * Confirm that callers never pass an empty batch.
 *
 * @param target_y_    Target outputs, forwarded unchanged to calculateLoss().
 * @param predicted_y_ Predicted outputs. Forwarded to calculateLoss(), and
 *                     its column count is used as the divisor.
 * @return Mean loss per batch sample.
 */
51 virtual dtype
calculateMeanLoss (mic::types::MatrixPtr<dtype> target_y_, mic::types::MatrixPtr<dtype> predicted_y_) {
// Normalise the loss by the number of columns (samples) in the batch.
52 return calculateLoss(target_y_, predicted_y_) / predicted_y_->cols();
/**
 * @brief Computes the gradient of the loss for the given target and predicted outputs.
 *
 * Pure virtual: every concrete loss function must provide an implementation.
 *
 * @param target_y_    Target (desired) outputs.
 * @param predicted_y_ Predicted outputs.
 * @return Pointer to a matrix holding the loss gradient.
 *         NOTE(review): presumably taken with respect to predicted_y_ and
 *         shaped like it. Confirm in implementations.
 */
58 virtual mic::types::MatrixPtr<dtype>
calculateGradient (mic::types::MatrixPtr<dtype> target_y_, mic::types::MatrixPtr<dtype> predicted_y_) = 0;
virtual dtype calculateLoss(mic::types::MatrixPtr< dtype > target_y_, mic::types::MatrixPtr< dtype > predicted_y_)=0
Computes the loss for the given target and predicted outputs (pure virtual; implemented by derived classes).
Abstract base class representing a loss function. It defines the interface that concrete loss functions implement.
virtual mic::types::MatrixPtr< dtype > calculateGradient(mic::types::MatrixPtr< dtype > target_y_, mic::types::MatrixPtr< dtype > predicted_y_)=0
Computes the gradient of the loss for the given target and predicted outputs (pure virtual; implemented by derived classes).
virtual dtype calculateMeanLoss(mic::types::MatrixPtr< dtype > target_y_, mic::types::MatrixPtr< dtype > predicted_y_)
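Calculates the mean loss, i.e. the loss divided by the batch size. This gives, for example, the average cross-entropy (ACE) for a cross-entropy loss or the mean squared error (MSE) for a squared-error loss.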