GEGELATI
Learn::ClassificationLearningAgent< BaseLearningAgent > Class Template Reference

LearningAgent specialized for LearningEnvironments representing a classification problem.

#include <classificationLearningAgent.h>

Inheritance diagram for Learn::ClassificationLearningAgent< BaseLearningAgent >:
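Learn::ClassificationLearningAgent< BaseLearningAgent > inherits from Learn::ParallelLearningAgent (the default BaseLearningAgent), which inherits from Learn::LearningAgent.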

Public Member Functions

 ClassificationLearningAgent (ClassificationLearningEnvironment &le, const Instructions::Set &iSet, const LearningParameters &p, const TPG::TPGFactory &factory=TPG::TPGFactory())
 Constructor for ClassificationLearningAgent.
 
virtual std::shared_ptr< EvaluationResult > evaluateJob (TPG::TPGExecutionEngine &tee, const Job &root, uint64_t generationNumber, LearningMode mode, LearningEnvironment &le) const override
 Specialization of the evaluateJob method for classification purposes.
 
void decimateWorstRoots (std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &results) override
 Specialization of the decimateWorstRoots method for classification purposes.
 
- Public Member Functions inherited from Learn::ParallelLearningAgent
 ParallelLearningAgent (LearningEnvironment &le, const Instructions::Set &iSet, const LearningParameters &p, const TPG::TPGFactory &factory=TPG::TPGFactory())
 Constructor for ParallelLearningAgent.
 
std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > evaluateAllRoots (uint64_t generationNumber, LearningMode mode) override
 Evaluate all root TPGVertex of the TPGGraph.
 
- Public Member Functions inherited from Learn::LearningAgent
 LearningAgent (LearningEnvironment &le, const Instructions::Set &iSet, const LearningParameters &p, const TPG::TPGFactory &factory=TPG::TPGFactory())
 Constructor for LearningAgent.
 
virtual ~LearningAgent ()=default
 Default destructor for polymorphism.
 
std::shared_ptr< TPG::TPGGraph > getTPGGraph ()
 Getter for the TPGGraph built by the LearningAgent.
 
const Archive & getArchive () const
 Getter for the Archive filled by the LearningAgent.
 
Mutator::RNG & getRNG ()
 Getter for the RNG used by the LearningAgent.
 
void addLogger (Log::LALogger &logger)
 Adds a LALogger to the loggers vector.
 
virtual std::shared_ptr< EvaluationResult > evaluateJob (TPG::TPGExecutionEngine &tee, const Job &job, uint64_t generationNumber, LearningMode mode, LearningEnvironment &le) const
 Evaluates the policy starting from the given root.
 
bool isRootEvalSkipped (const TPG::TPGVertex &root, std::shared_ptr< Learn::EvaluationResult > &previousResult) const
 Method detecting whether a root should be evaluated again.
 
virtual std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > evaluateAllRoots (uint64_t generationNumber, LearningMode mode)
 Evaluate all root TPGVertex of the TPGGraph.
 
virtual void trainOneGeneration (uint64_t generationNumber)
 Train the TPGGraph for one generation.
 
virtual void decimateWorstRoots (std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &results)
 Removes from the TPGGraph the root TPGVertex with the worst results.
 
uint64_t train (volatile bool &altTraining, bool printProgressBar)
 Train the TPGGraph for a given number of generations.
 
void updateEvaluationRecords (const std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &results)
 Update the bestRoot and resultsPerRoot attributes.
 
void forgetPreviousResults ()
 This method resets the previously registered scores per root.
 
const std::pair< const TPG::TPGVertex *, std::shared_ptr< EvaluationResult > > & getBestRoot () const
 Get the best root TPG::TPGVertex encountered since the last init.
 
void keepBestPolicy ()
 This method keeps only the bestRoot policy in the TPGGraph.
 
virtual std::shared_ptr< Learn::Job > makeJob (int num, Learn::LearningMode mode, int idx=0, TPG::TPGGraph *tpgGraph=nullptr)
 Takes a given root index and creates a job containing it. Useful, for example, in adversarial mode, where a job could contain a match of several roots.
 
virtual std::queue< std::shared_ptr< Learn::Job > > makeJobs (Learn::LearningMode mode, TPG::TPGGraph *tpgGraph=nullptr)
 Puts all roots into jobs to be able to use them in simulation later.
 
void init (uint64_t seed=0)
 Initialize the LearningAgent.
 

Additional Inherited Members

- Protected Member Functions inherited from Learn::ParallelLearningAgent
virtual void evaluateAllRootsInParallel (uint64_t generationNumber, LearningMode mode, std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &results)
 Method for evaluating all roots with parallelism.
 
virtual void evaluateAllRootsInParallelExecute (uint64_t generationNumber, LearningMode mode, std::map< uint64_t, std::pair< std::shared_ptr< EvaluationResult >, std::shared_ptr< Job > > > &resultsPerJobMap, std::map< uint64_t, Archive * > &archiveMap)
 Subfunction of evaluateAllRootsInParallel which handles the creation of threads, their execution and junction.
 
virtual void evaluateAllRootsInParallelCompileResults (std::map< uint64_t, std::pair< std::shared_ptr< EvaluationResult >, std::shared_ptr< Job > > > &resultsPerJobMap, std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &results, std::map< uint64_t, Archive * > &archiveMap)
 Subfunction of evaluateAllRootsInParallel which handles the gathering of results and the merge of the archives.
 
void slaveEvalJobThread (uint64_t generationNumber, LearningMode mode, std::queue< std::shared_ptr< Learn::Job > > &jobsToProcess, std::mutex &rootsToProcessMutex, std::map< uint64_t, std::pair< std::shared_ptr< EvaluationResult >, std::shared_ptr< Job > > > &resultsPerRootMap, std::mutex &resultsPerRootMapMutex, std::map< uint64_t, Archive * > &archiveMap, std::mutex &archiveMapMutex, bool useMainEnvironment)
 Function implementing the behavior of slave threads during parallel evaluation of roots.
 
void mergeArchiveMap (std::map< uint64_t, Archive * > &archiveMap)
 Method to merge several Archives created in parallel threads.
 
- Protected Attributes inherited from Learn::LearningAgent
LearningEnvironment & learningEnvironment
 LearningEnvironment with which the LearningAgent will interact.
 
Environment env
 Environment for executing Program of the LearningAgent.
 
Archive archive
 Archive used during the training process.
 
LearningParameters params
 Parameters for the learning process.
 
std::shared_ptr< TPG::TPGGraph > tpg
 TPGGraph built during the learning process.
 
std::pair< const TPG::TPGVertex *, std::shared_ptr< EvaluationResult > > bestRoot {nullptr, nullptr}
 
std::map< const TPG::TPGVertex *, std::shared_ptr< EvaluationResult > > resultsPerRoot
 Map associating root TPG::TPGVertex to their EvaluationResult.
 
Mutator::RNG rng
 Random Number Generator for this Learning Agent.
 
uint64_t maxNbThreads = 1
 Control the maximum number of threads when running in parallel.
 
std::vector< std::reference_wrapper< Log::LALogger > > loggers
 Set of LALogger called throughout the training process.
 

Detailed Description

template<class BaseLearningAgent = ParallelLearningAgent>
class Learn::ClassificationLearningAgent< BaseLearningAgent >

LearningAgent specialized for LearningEnvironments representing a classification problem.

The key difference between this ClassificationLearningAgent and the base LearningAgent is the way roots are selected for decimation after each generation. The base agent decimates roots based on their global average score (over all classes) during the last evaluation; this agent instead decimates them based on an average score per class. As a result, the roots providing the best score in each class are preserved, which increases the chances that correct classifiers emerge for all classes.

In this context, it is assumed that each action of the LearningEnvironment represents a class of the classification problem.

The BaseLearningAgent template parameter is the LearningAgent from which the ClassificationLearningAgent inherits. This template parameter notably enables selecting between the classical LearningAgent and the ParallelLearningAgent.
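The base-class selection described above can be illustrated with a minimal, self-contained sketch. The types SequentialAgent, ParallelAgent and ClassificationAgent are hypothetical stand-ins for Learn::LearningAgent, Learn::ParallelLearningAgent and this class template; they are not part of the library, and the member functions are invented for illustration only.

```cpp
#include <string>

// Hypothetical stand-in for Learn::LearningAgent (not the library class).
struct SequentialAgent {
    virtual ~SequentialAgent() = default;
    virtual std::string evaluationStrategy() const { return "sequential"; }
};

// Hypothetical stand-in for Learn::ParallelLearningAgent.
struct ParallelAgent : SequentialAgent {
    std::string evaluationStrategy() const override { return "parallel"; }
};

// The specialized agent inherits from whichever base is chosen as the
// template argument, and defaults to the parallel one.
template <class BaseAgent = ParallelAgent>
class ClassificationAgent : public BaseAgent {
  public:
    // Behavior added by the specialization, independent of the base.
    std::string decimationStrategy() const { return "per-class"; }
};
```

With these stand-ins, ClassificationAgent<> evaluates in parallel while ClassificationAgent<SequentialAgent> evaluates sequentially, and both share the same per-class decimation behavior.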

Constructor & Destructor Documentation

◆ ClassificationLearningAgent()

template<class BaseLearningAgent = ParallelLearningAgent>
Learn::ClassificationLearningAgent< BaseLearningAgent >::ClassificationLearningAgent ( ClassificationLearningEnvironment & le,
const Instructions::Set & iSet,
const LearningParameters & p,
const TPG::TPGFactory & factory = TPG::TPGFactory() 
)
inline

Constructor for ClassificationLearningAgent.

Parameters
[in]  le       The LearningEnvironment for the TPG.
[in]  iSet     Set of Instruction used to compose Programs in the learning process.
[in]  p        The LearningParameters for the LearningAgent.
[in]  factory  The TPGFactory used to create the TPGGraph. A default TPGFactory is used if none is provided.

Member Function Documentation

◆ decimateWorstRoots()

template<class BaseLearningAgent >
void Learn::ClassificationLearningAgent< BaseLearningAgent >::decimateWorstRoots ( std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &  results)
override virtual

Specialization of the decimateWorstRoots method for classification purposes.

During the decimation process, roughly half of the preserved roots are kept based on their score for individual classes of the ClassificationLearningEnvironment. To do so, for each class of the ClassificationLearningEnvironment, the roots providing the best score are preserved during the decimation process, even if their global score over all classes is not among the best.

The remaining half of preserved roots is selected using the general score obtained over all classes.

This per-class preservation is activated only if enough root vertices remain in the TPGGraph after decimation to guarantee that all classes are preserved equally. In other words, the same number of roots is marked for preservation for each class, which can only be achieved if the number of roots to preserve during the decimation process is greater than or equal to twice the number of actions of the ClassificationLearningEnvironment. If too few roots are preserved during the decimation process, all preserved roots are selected based on their general score.

The results map is updated by the method to keep only the results of non-decimated roots.
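The selection rule described above can be sketched as follows. This is an illustrative, self-contained approximation and not the library implementation: RootScore, rankBy and selectPreservedRoots are hypothetical names, and the way ties and roots that are best in several classes are handled is an assumption.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <set>
#include <vector>

// Hypothetical stand-in for a root and its classification result.
struct RootScore {
    std::vector<double> scorePerClass; // one score per class (action)
    double generalScore;               // score over all classes
};

// Indices 0..n-1 sorted by decreasing key(index); ties keep index order.
template <class Key>
std::vector<std::size_t> rankBy(std::size_t n, Key key) {
    std::vector<std::size_t> order(n);
    std::iota(order.begin(), order.end(), std::size_t{0});
    std::stable_sort(order.begin(), order.end(),
                     [&](std::size_t a, std::size_t b) { return key(a) > key(b); });
    return order;
}

// Returns the indices of the roots to preserve (illustrative only).
std::set<std::size_t> selectPreservedRoots(const std::vector<RootScore>& roots,
                                           std::size_t nbClasses,
                                           std::size_t nbToKeep) {
    std::set<std::size_t> kept;
    nbToKeep = std::min(nbToKeep, roots.size());

    // Per-class preservation only if every class can get the same share,
    // i.e. if at least twice as many roots as classes are preserved.
    if (nbClasses > 0 && nbToKeep >= 2 * nbClasses) {
        const std::size_t perClass = (nbToKeep / 2) / nbClasses;
        for (std::size_t c = 0; c < nbClasses; ++c) {
            const auto order = rankBy(roots.size(), [&](std::size_t i) {
                return roots[i].scorePerClass[c];
            });
            // Assumption: a root already kept for another class is not replaced.
            for (std::size_t i = 0; i < perClass; ++i) kept.insert(order[i]);
        }
    }

    // Remaining slots are filled using the general score.
    const auto order = rankBy(roots.size(), [&](std::size_t i) {
        return roots[i].generalScore;
    });
    for (std::size_t idx : order) {
        if (kept.size() >= nbToKeep) break;
        kept.insert(idx);
    }
    return kept;
}
```

For instance, with two classes and four roots to preserve, one root is kept for each class and the two remaining slots go to the roots with the best general score. With only three roots to preserve, the threshold of twice the number of classes is not met, so the general score alone decides.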

Reimplemented from Learn::LearningAgent.

◆ evaluateJob()

template<class BaseLearningAgent >
std::shared_ptr< EvaluationResult > Learn::ClassificationLearningAgent< BaseLearningAgent >::evaluateJob ( TPG::TPGExecutionEngine & tee,
const Job & root,
uint64_t  generationNumber,
LearningMode  mode,
LearningEnvironment & le 
) const
inline override virtual

Specialization of the evaluateJob method for classification purposes.

This method returns a ClassificationEvaluationResult for the evaluated root instead of the usual EvaluationResult. For each class, the score recorded in this result corresponds to the F1 score for that class.

Reimplemented from Learn::LearningAgent.

