MachineIntelligenceCore:ReinforcementLearning
mic::application::GridworldDeepQLearning Class Reference

Class responsible for solving the gridworld problem with Q-learning and (not that) deep neural networks.

#include <GridworldDeepQLearning.hpp>


Public Member Functions

 GridworldDeepQLearning (std::string node_name_="application")
 
virtual ~GridworldDeepQLearning ()
 

Protected Member Functions

virtual void initialize (int argc, char *argv[])
 
virtual void initializePropertyDependentVariables ()
 
virtual bool performSingleStep ()
 
virtual void startNewEpisode ()
 
virtual void finishCurrentEpisode ()
 

Private Member Functions

float computeBestValueForCurrentState ()
 
mic::types::MatrixXfPtr getPredictedRewardsForCurrentState ()
 
mic::types::NESWAction selectBestActionForCurrentState ()
 
std::string streamNetworkResponseTable ()
 

Private Attributes

WindowCollectorChart< float > * w_chart
 Window for displaying statistics.
 
mic::utils::DataCollectorPtr< std::string, float > collector_ptr
 Data collector.
 
mic::environments::Gridworld grid_env
 The gridworld environment.
 
mic::configuration::Property< float > step_reward
 
mic::configuration::Property< float > discount_rate
 
mic::configuration::Property< float > learning_rate
 
mic::configuration::Property< double > epsilon
 
mic::configuration::Property< std::string > statistics_filename
 Property: name of the file to which the statistics will be exported.
 
mic::configuration::Property< std::string > mlnn_filename
 Property: name of the file to which the neural network will be serialized (or deserialized from).
 
mic::configuration::Property< bool > mlnn_save
 Property: flag denoting whether the neural network should be saved to a file (after the end of every episode).
 
mic::configuration::Property< bool > mlnn_load
 Property: flag denoting whether the neural network should be loaded from a file (at the initialization of the task).
 
BackpropagationNeuralNetwork< float > neural_net
 Multi-layer neural network used for approximation of the Qstate rewards.
 
mic::types::Position2D player_pos_t_minus_prim
 
long long sum_of_iterations
 
long long sum_of_rewards
 

Detailed Description

Class responsible for solving the gridworld problem with Q-learning and (not that) deep neural networks.

Author
tkornuta

Definition at line 48 of file GridworldDeepQLearning.hpp.

Constructor & Destructor Documentation

mic::application::GridworldDeepQLearning::GridworldDeepQLearning ( std::string  node_name_ = "application")

Default Constructor. Sets the application/node name, default values of variables, initializes classifier etc.

Parameters
node_name_  Name of the application/node (in configuration file).

Definition at line 39 of file GridworldDeepQLearning.cpp.

References discount_rate, epsilon, learning_rate, mlnn_filename, mlnn_load, mlnn_save, statistics_filename, and step_reward.

mic::application::GridworldDeepQLearning::~GridworldDeepQLearning ( )
virtual

Destructor.

Definition at line 63 of file GridworldDeepQLearning.cpp.

References w_chart.

Member Function Documentation

float mic::application::GridworldDeepQLearning::computeBestValueForCurrentState ( )
private

Calculates the best value for the current state - by finding the action having the maximal expected value.

Returns
Value for given state.

Definition at line 215 of file GridworldDeepQLearning.cpp.

References getPredictedRewardsForCurrentState(), grid_env, and mic::environments::Environment::isActionAllowed().

Referenced by performSingleStep().
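
The description above amounts to a maximum over the actions that the environment currently allows. A minimal sketch of that idea follows; it does not reproduce the class's implementation. The function name `bestAllowedValue`, the plain vectors standing in for the network output and for `isActionAllowed()`, and the choice to return negative infinity when nothing is allowed are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper: q[i] is the predicted value of action i and
// allowed[i] says whether action i is legal in the current state.
// Returns the largest predicted value among the allowed actions, or
// negative infinity if no action is allowed (an assumption of this sketch).
float bestAllowedValue(const std::vector<float>& q,
                       const std::vector<bool>& allowed) {
    assert(q.size() == allowed.size());
    float best = -std::numeric_limits<float>::infinity();
    for (std::size_t i = 0; i < q.size(); ++i) {
        // Skip actions the environment forbids, even if their value is high.
        if (allowed[i] && q[i] > best) {
            best = q[i];
        }
    }
    return best;
}
```

With four NESW actions, a disallowed action is ignored even when it carries the highest prediction, so the result is the best value that can actually be reached from the current state.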

void mic::application::GridworldDeepQLearning::finishCurrentEpisode ( )
protected virtual

Method called when given episode ends (goal: export collected statistics to file etc.) - abstract, to be overridden.

Definition at line 128 of file GridworldDeepQLearning.cpp.

References collector_ptr, mic::environments::Gridworld::getAgentPosition(), mic::environments::Gridworld::getStateReward(), grid_env, mlnn_filename, mlnn_save, neural_net, statistics_filename, sum_of_iterations, and sum_of_rewards.

mic::types::MatrixXfPtr mic::application::GridworldDeepQLearning::getPredictedRewardsForCurrentState ( )
private

Returns the predicted rewards for given state.

Returns
Pointer to predicted rewards (network output matrix).

Definition at line 244 of file GridworldDeepQLearning.cpp.

References mic::environments::Gridworld::encodeAgentGrid(), grid_env, and neural_net.

Referenced by computeBestValueForCurrentState(), performSingleStep(), and selectBestActionForCurrentState().

void mic::application::GridworldDeepQLearning::initialize ( int  argc, char *  argv[] )
protected virtual

Method initializes GLUT and OpenGL windows.

Parameters
argc  Number of application parameters.
argv  Array of application parameters.

Definition at line 68 of file GridworldDeepQLearning.cpp.

References collector_ptr, sum_of_iterations, sum_of_rewards, and w_chart.

void mic::application::GridworldDeepQLearning::initializePropertyDependentVariables ( )
protected virtual

mic::types::NESWAction mic::application::GridworldDeepQLearning::selectBestActionForCurrentState ( )
private

Finds the best action for the current state.

Returns
The best action found.

Definition at line 254 of file GridworldDeepQLearning.cpp.

References getPredictedRewardsForCurrentState(), grid_env, and mic::environments::Environment::isActionAllowed().

Referenced by performSingleStep().

void mic::application::GridworldDeepQLearning::startNewEpisode ( )
protected virtual

Method called at the beginning of new episode (goal: to reset the statistics etc.) - abstract, to be overridden.

Definition at line 116 of file GridworldDeepQLearning.cpp.

References mic::environments::Gridworld::environmentToString(), grid_env, mic::environments::Gridworld::initializeEnvironment(), and streamNetworkResponseTable().

std::string mic::application::GridworldDeepQLearning::streamNetworkResponseTable ( )
private

Member Data Documentation

mic::utils::DataCollectorPtr<std::string, float> mic::application::GridworldDeepQLearning::collector_ptr
private

Data collector.

Definition at line 97 of file GridworldDeepQLearning.hpp.

Referenced by finishCurrentEpisode(), and initialize().

mic::configuration::Property<float> mic::application::GridworldDeepQLearning::discount_rate
private

Property: future discount (should be in range 0.0-1.0).

Definition at line 110 of file GridworldDeepQLearning.hpp.

Referenced by GridworldDeepQLearning(), and performSingleStep().
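
For orientation, the discount rate usually enters Q-learning through the one-step target r + γ · max Q(s', a'). The sketch below shows that textbook formula only; how performSingleStep() actually combines step_reward, discount_rate and the network output is not documented on this page. The function name `qLearningTarget` and the `is_terminal` flag are hypothetical.

```cpp
// Textbook one-step Q-learning target (not taken from the class's source):
//   target = reward + discount_rate * best_next_value
// For a terminal state there is no future value, so the target is the reward.
float qLearningTarget(float reward, float discount_rate,
                      float best_next_value, bool is_terminal) {
    if (is_terminal) {
        return reward;
    }
    return reward + discount_rate * best_next_value;
}
```

A discount rate close to 0.0 makes the target depend almost entirely on the immediate reward, while a value close to 1.0 weights future value nearly as much as the current step.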

mic::configuration::Property<double> mic::application::GridworldDeepQLearning::epsilon
private

Property: variable denoting epsilon in action selection (the probability "below" which a random action will be selected). If epsilon < 0, it will be set to 1/episode and hence change dynamically depending on the episode number.

Definition at line 121 of file GridworldDeepQLearning.hpp.

Referenced by GridworldDeepQLearning(), and performSingleStep().
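
The rule described above can be sketched as follows. This is an illustration under stated assumptions, not the class's code: the names `effectiveEpsilon` and `shouldExplore` are hypothetical, and episode numbering is assumed to start at 1 so that 1/episode is well defined.

```cpp
#include <cassert>
#include <random>

// Hypothetical helper: a negative configured epsilon is replaced by
// 1/episode, so exploration decays as training progresses.
// Assumes episodes are numbered from 1.
double effectiveEpsilon(double configured_epsilon, unsigned long episode) {
    assert(episode >= 1);
    if (configured_epsilon < 0.0) {
        return 1.0 / static_cast<double>(episode);
    }
    return configured_epsilon;
}

// Hypothetical helper: returns true when a random action should be taken,
// i.e. when a uniform sample from [0, 1) falls below epsilon.
template <class Rng>
bool shouldExplore(double eps, Rng& rng) {
    std::uniform_real_distribution<double> dist(0.0, 1.0);
    return dist(rng) < eps;
}
```

With a negative setting, the exploration probability is 1.0 in the first episode, 0.5 in the second, 0.25 in the fourth, and so on; a non-negative setting is used unchanged.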

mic::configuration::Property<float> mic::application::GridworldDeepQLearning::learning_rate
private

Property: neural network learning rate (should be in range 0.0-1.0).

Definition at line 115 of file GridworldDeepQLearning.hpp.

Referenced by GridworldDeepQLearning(), and performSingleStep().

mic::configuration::Property<std::string> mic::application::GridworldDeepQLearning::mlnn_filename
private

Property: name of the file to which the neural network will be serialized (or deserialized from).

Definition at line 127 of file GridworldDeepQLearning.hpp.

Referenced by finishCurrentEpisode(), GridworldDeepQLearning(), and initializePropertyDependentVariables().

mic::configuration::Property<bool> mic::application::GridworldDeepQLearning::mlnn_load
private

Property: flag denoting whether the neural network should be loaded from a file (at the initialization of the task).

Definition at line 133 of file GridworldDeepQLearning.hpp.

Referenced by GridworldDeepQLearning(), and initializePropertyDependentVariables().

mic::configuration::Property<bool> mic::application::GridworldDeepQLearning::mlnn_save
private

Property: flag denoting whether the neural network should be saved to a file (after the end of every episode).

Definition at line 130 of file GridworldDeepQLearning.hpp.

Referenced by finishCurrentEpisode(), and GridworldDeepQLearning().

BackpropagationNeuralNetwork<float> mic::application::GridworldDeepQLearning::neural_net
private

Multi-layer neural network used for approximation of the Qstate rewards.

Definition at line 136 of file GridworldDeepQLearning.hpp.

Referenced by finishCurrentEpisode(), getPredictedRewardsForCurrentState(), initializePropertyDependentVariables(), performSingleStep(), and streamNetworkResponseTable().

mic::types::Position2D mic::application::GridworldDeepQLearning::player_pos_t_minus_prim
private

Player position at time (t-1).

Definition at line 165 of file GridworldDeepQLearning.hpp.

Referenced by performSingleStep().

mic::configuration::Property<std::string> mic::application::GridworldDeepQLearning::statistics_filename
private

Property: name of the file to which the statistics will be exported.

Definition at line 124 of file GridworldDeepQLearning.hpp.

Referenced by finishCurrentEpisode(), and GridworldDeepQLearning().

mic::configuration::Property<float> mic::application::GridworldDeepQLearning::step_reward
private

Property: the "expected intermediate reward", i.e. the reward received for performing each step (typically negative, but it can be positive as well).

Definition at line 105 of file GridworldDeepQLearning.hpp.

Referenced by GridworldDeepQLearning(), and performSingleStep().

long long mic::application::GridworldDeepQLearning::sum_of_iterations
private

Sum of all iterations made till now - used in statistics.

Definition at line 170 of file GridworldDeepQLearning.hpp.

Referenced by finishCurrentEpisode(), and initialize().

long long mic::application::GridworldDeepQLearning::sum_of_rewards
private

Sum of all rewards collected till now - used in statistics.

Definition at line 175 of file GridworldDeepQLearning.hpp.

Referenced by finishCurrentEpisode(), and initialize().

WindowCollectorChart<float>* mic::application::GridworldDeepQLearning::w_chart
private

Window for displaying statistics.

Definition at line 94 of file GridworldDeepQLearning.hpp.

Referenced by initialize(), and ~GridworldDeepQLearning().


The documentation for this class was generated from the following files: