MachineIntelligenceCore:NeuralNets
mic::neural_nets::loss::CrossEntropyLoss< dtype > Class Template Reference

Class representing a cross-entropy loss function (classification).

#include <CrossEntropyLoss.hpp>

[Inheritance and collaboration diagrams for mic::neural_nets::loss::CrossEntropyLoss< dtype > (images not available); base class: mic::neural_nets::loss::Loss< dtype >]

Public Member Functions

dtype calculateLoss (mic::types::MatrixPtr< dtype > target_y_, mic::types::MatrixPtr< dtype > predicted_y_)
 Calculates the cross-entropy (using a logarithm) and returns the cross-entropy error (CE).
 
mic::types::MatrixPtr< dtype > calculateGradient (mic::types::MatrixPtr< dtype > target_y_, mic::types::MatrixPtr< dtype > predicted_y_)
 Gradient calculation for cross-entropy.
 
- Public Member Functions inherited from mic::neural_nets::loss::Loss< dtype >
virtual dtype calculateMeanLoss (mic::types::MatrixPtr< dtype > target_y_, mic::types::MatrixPtr< dtype > predicted_y_)
 Calculates the mean loss (i.e., divides the loss by the batch size): ACE (average cross-entropy) for cross-entropy, or MSE (mean squared error) for regression.
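As a worked definition of the CE and ACE quantities above, here is a sketch under stated assumptions: calculateLoss is assumed to sum the cross-entropy over every sample in the batch, and log is taken as the natural logarithm (the page says only "using log"). The symbols B (batch size), N (number of classes), t (entries of target_y_) and p (entries of predicted_y_) are notation introduced here, not taken from the library.

```latex
\mathrm{CE} = -\sum_{b=1}^{B} \sum_{i=1}^{N} t_{i,b} \, \log\left(p_{i,b}\right),
\qquad
\mathrm{ACE} = \frac{\mathrm{CE}}{B}
```

Dividing by B makes the reported value independent of the batch size, so losses from batches of different sizes can be compared directly.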
 

Detailed Description

template<typename dtype = float>
class mic::neural_nets::loss::CrossEntropyLoss< dtype >

Class representing a cross-entropy loss function (classification).

Author
tkornuta
Template Parameters
dtype: template parameter denoting the precision of variables (defaults to float).

Definition at line 43 of file CrossEntropyLoss.hpp.

Member Function Documentation

template<typename dtype = float>
mic::types::MatrixPtr<dtype> mic::neural_nets::loss::CrossEntropyLoss< dtype >::calculateGradient(mic::types::MatrixPtr< dtype > target_y_, mic::types::MatrixPtr< dtype > predicted_y_)
[inline, virtual]

Gradient calculation for cross-entropy.

Implements mic::neural_nets::loss::Loss< dtype >.

Definition at line 68 of file CrossEntropyLoss.hpp.

Referenced by TEST_F().
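This page does not state which gradient calculateGradient returns. A common choice when cross-entropy follows a softmax layer is the combined softmax/cross-entropy gradient with respect to the logits, which simplifies to (predicted - target) for one-hot targets. The sketch below assumes that convention; crossEntropyGradient is a hypothetical free function, and plain std::vector stands in for mic::types::MatrixPtr.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for CrossEntropyLoss<dtype>::calculateGradient.
// Assumptions: predicted_y holds softmax outputs and target_y is one-hot
// (sums to 1). Under those assumptions the gradient of the cross-entropy
// with respect to the pre-softmax inputs (logits) is (predicted - target).
// This is NOT the gradient with respect to the probabilities themselves,
// which would be -target / predicted.
template <typename dtype = float>
std::vector<dtype> crossEntropyGradient(const std::vector<dtype>& target_y,
                                        const std::vector<dtype>& predicted_y) {
    // Targets and predictions must have the same layout.
    assert(target_y.size() == predicted_y.size());
    std::vector<dtype> dy(predicted_y.size());
    for (std::size_t i = 0; i < predicted_y.size(); ++i) {
        dy[i] = predicted_y[i] - target_y[i];
    }
    return dy;
}
```

For example, a one-hot target {0, 1, 0} with prediction {0.25, 0.5, 0.25} yields the gradient {0.25, -0.5, 0.25}: the true class is pushed up, the others down.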

template<typename dtype = float>
dtype mic::neural_nets::loss::CrossEntropyLoss< dtype >::calculateLoss(mic::types::MatrixPtr< dtype > target_y_, mic::types::MatrixPtr< dtype > predicted_y_)
[inline, virtual]

Calculates the cross-entropy (using a logarithm) and returns the cross-entropy error (CE).

Implements mic::neural_nets::loss::Loss< dtype >.

Definition at line 48 of file CrossEntropyLoss.hpp.

Referenced by TEST_F().
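A minimal sketch of the cross-entropy computation described above, assuming targets and predictions are probability vectors of equal size, the natural logarithm, and a small epsilon to avoid log(0). The function name crossEntropyLoss and the epsilon value are hypothetical choices, and std::vector stands in for mic::types::MatrixPtr; this is not the library's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for CrossEntropyLoss<dtype>::calculateLoss.
// Assumptions: target_y is a one-hot (or probability) vector, predicted_y
// holds probabilities (e.g. softmax outputs), both use the same layout,
// and the natural logarithm is used.
template <typename dtype = float>
dtype crossEntropyLoss(const std::vector<dtype>& target_y,
                       const std::vector<dtype>& predicted_y) {
    // Targets and predictions must have the same layout.
    assert(target_y.size() == predicted_y.size());
    // Small epsilon keeps log() finite when a predicted probability is 0.
    const dtype eps = static_cast<dtype>(1e-15);
    dtype loss = 0;
    for (std::size_t i = 0; i < predicted_y.size(); ++i) {
        // CE = -sum_i t_i * log(p_i)
        loss -= target_y[i] * std::log(predicted_y[i] + eps);
    }
    return loss;
}
```

For example, a one-hot target {0, 1, 0} with prediction {0.25, 0.5, 0.25} gives -log(0.5) = log(2), approximately 0.693; only the entry of the true class contributes.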

