GEGELATI
LearningAgent specialized for LearningEnvironments representing a classification problem.
#include <classificationLearningAgent.h>
Public Member Functions
ClassificationLearningAgent (ClassificationLearningEnvironment &le, const Instructions::Set &iSet, const LearningParameters &p, const TPG::TPGFactory &factory=TPG::TPGFactory())
Constructor for ClassificationLearningAgent.
virtual std::shared_ptr< EvaluationResult > evaluateJob (TPG::TPGExecutionEngine &tee, const Job &root, uint64_t generationNumber, LearningMode mode, LearningEnvironment &le) const override
Specialization of the evaluateJob method for classification purposes.
void decimateWorstRoots (std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &results) override
Specialization of the decimateWorstRoots method for classification purposes.
Public Member Functions inherited from Learn::ParallelLearningAgent
ParallelLearningAgent (LearningEnvironment &le, const Instructions::Set &iSet, const LearningParameters &p, const TPG::TPGFactory &factory=TPG::TPGFactory())
Constructor for ParallelLearningAgent.
std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > evaluateAllRoots (uint64_t generationNumber, LearningMode mode) override
Evaluate all root TPGVertex of the TPGGraph.
Public Member Functions inherited from Learn::LearningAgent
LearningAgent (LearningEnvironment &le, const Instructions::Set &iSet, const LearningParameters &p, const TPG::TPGFactory &factory=TPG::TPGFactory())
Constructor for LearningAgent.
virtual ~LearningAgent ()=default
Default destructor for polymorphism.
std::shared_ptr< TPG::TPGGraph > getTPGGraph ()
Getter for the TPGGraph built by the LearningAgent.
const Archive & getArchive () const
Getter for the Archive filled by the LearningAgent.
Mutator::RNG & getRNG ()
Getter for the RNG used by the LearningAgent.
void addLogger (Log::LALogger &logger)
Adds a LALogger to the loggers vector.
virtual std::shared_ptr< EvaluationResult > evaluateJob (TPG::TPGExecutionEngine &tee, const Job &job, uint64_t generationNumber, LearningMode mode, LearningEnvironment &le) const
Evaluates the policy starting from the given root.
bool isRootEvalSkipped (const TPG::TPGVertex &root, std::shared_ptr< Learn::EvaluationResult > &previousResult) const
Method detecting whether a root should be evaluated again.
virtual std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > evaluateAllRoots (uint64_t generationNumber, LearningMode mode)
Evaluate all root TPGVertex of the TPGGraph.
virtual void trainOneGeneration (uint64_t generationNumber)
Train the TPGGraph for one generation.
virtual void decimateWorstRoots (std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &results)
Removes from the TPGGraph the root TPGVertex with the worst results.
uint64_t train (volatile bool &altTraining, bool printProgressBar)
Train the TPGGraph for a given number of generations.
void updateEvaluationRecords (const std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &results)
Update the bestRoot and resultsPerRoot attributes.
void forgetPreviousResults ()
This method resets the previously registered scores per root.
const std::pair< const TPG::TPGVertex *, std::shared_ptr< EvaluationResult > > & getBestRoot () const
Get the best root TPG::TPGVertex encountered since the last init.
void keepBestPolicy ()
This method keeps only the bestRoot policy in the TPGGraph.
virtual std::shared_ptr< Learn::Job > makeJob (int num, Learn::LearningMode mode, int idx=0, TPG::TPGGraph *tpgGraph=nullptr)
Takes a given root index and creates a job containing it. Useful, for example, in adversarial mode, where a job could contain a match of several roots.
virtual std::queue< std::shared_ptr< Learn::Job > > makeJobs (Learn::LearningMode mode, TPG::TPGGraph *tpgGraph=nullptr)
Puts all roots into jobs so that they can be used in simulation later.
void init (uint64_t seed=0)
Initialize the LearningAgent.
Additional Inherited Members
Protected Member Functions inherited from Learn::ParallelLearningAgent
virtual void evaluateAllRootsInParallel (uint64_t generationNumber, LearningMode mode, std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &results)
Method for evaluating all roots with parallelism.
virtual void evaluateAllRootsInParallelExecute (uint64_t generationNumber, LearningMode mode, std::map< uint64_t, std::pair< std::shared_ptr< EvaluationResult >, std::shared_ptr< Job > > > &resultsPerJobMap, std::map< uint64_t, Archive * > &archiveMap)
Subfunction of evaluateAllRootsInParallel which handles the creation of threads, their execution and their joining.
virtual void evaluateAllRootsInParallelCompileResults (std::map< uint64_t, std::pair< std::shared_ptr< EvaluationResult >, std::shared_ptr< Job > > > &resultsPerJobMap, std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &results, std::map< uint64_t, Archive * > &archiveMap)
Subfunction of evaluateAllRootsInParallel which handles the gathering of results and the merge of the archives.
void slaveEvalJobThread (uint64_t generationNumber, LearningMode mode, std::queue< std::shared_ptr< Learn::Job > > &jobsToProcess, std::mutex &rootsToProcessMutex, std::map< uint64_t, std::pair< std::shared_ptr< EvaluationResult >, std::shared_ptr< Job > > > &resultsPerRootMap, std::mutex &resultsPerRootMapMutex, std::map< uint64_t, Archive * > &archiveMap, std::mutex &archiveMapMutex, bool useMainEnvironment)
Function implementing the behavior of slave threads during parallel evaluation of roots.
void mergeArchiveMap (std::map< uint64_t, Archive * > &archiveMap)
Method to merge several Archives created in parallel threads.
Protected Attributes inherited from Learn::LearningAgent
LearningEnvironment & learningEnvironment
LearningEnvironment with which the LearningAgent will interact.
Environment env
Environment for executing the Programs of the LearningAgent.
Archive archive
Archive used during the training process.
LearningParameters params
Parameters for the learning process.
std::shared_ptr< TPG::TPGGraph > tpg
TPGGraph built during the learning process.
std::pair< const TPG::TPGVertex *, std::shared_ptr< EvaluationResult > > bestRoot {nullptr, nullptr}
std::map< const TPG::TPGVertex *, std::shared_ptr< EvaluationResult > > resultsPerRoot
Map associating root TPG::TPGVertex to their EvaluationResult.
Mutator::RNG rng
Random Number Generator for this Learning Agent.
uint64_t maxNbThreads = 1
Controls the maximum number of threads when running in parallel.
std::vector< std::reference_wrapper< Log::LALogger > > loggers
Set of LALoggers called throughout the training process.
LearningAgent specialized for LearningEnvironments representing a classification problem.
The key difference between this ClassificationLearningAgent and the base LearningAgent is the way roots are selected for decimation after each generation. In this agent, roots are decimated based on their average score per class, instead of their global average score over all classes during the last evaluation. By doing so, the roots providing the best score in each class are preserved, which increases the chances that correct classifiers emerge for all classes.
In this context, it is assumed that each action of the LearningEnvironment represents a class of the classification problem.
The BaseLearningAgent template parameter is the LearningAgent from which the ClassificationLearningAgent inherits. This template parameter notably enables selecting between the classical LearningAgent and the ParallelLearningAgent.
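The template-based choice of base class can be sketched with simplified stand-in classes. The structs below are hypothetical mock-ups, not the GEGELATI classes, and the default template argument ParallelLearningAgent is an assumption suggested by the inherited members listed above; the real declaration may differ.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-ins for the GEGELATI agents; not the real classes.
struct LearningAgent {
    virtual ~LearningAgent() = default;
    virtual std::string name() const { return "LearningAgent"; }
};

struct ParallelLearningAgent : public LearningAgent {
    std::string name() const override { return "ParallelLearningAgent"; }
};

// The classification agent inherits from whichever agent is passed as the
// template argument; ParallelLearningAgent is assumed to be the default.
template <class BaseLearningAgent = ParallelLearningAgent>
struct ClassificationLearningAgent : public BaseLearningAgent {
    // Qualified (non-virtual) call to the selected base's name().
    std::string baseName() const { return BaseLearningAgent::name(); }
};

// The base class is fixed at compile time by the template argument.
static_assert(std::is_base_of<ParallelLearningAgent,
                              ClassificationLearningAgent<>>::value,
              "default base should be ParallelLearningAgent");
static_assert(!std::is_base_of<ParallelLearningAgent,
                               ClassificationLearningAgent<LearningAgent>>::value,
              "explicit LearningAgent base should not be parallel");
```

In this sketch, ClassificationLearningAgent<LearningAgent>().baseName() returns "LearningAgent", while the default instantiation reports "ParallelLearningAgent". Because the base is selected at compile time, switching between the sequential and parallel agents adds no run-time dispatch.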
ClassificationLearningAgent() [inline]
Constructor for ClassificationLearningAgent.
Parameters
[in] le: The LearningEnvironment for the TPG.
[in] iSet: Set of Instructions used to compose Programs in the learning process.
[in] p: The LearningParameters for the LearningAgent.
[in] factory: The TPGFactory used to create the TPGGraph. A default TPGFactory is used if none is provided.
decimateWorstRoots() [override, virtual]
Specialization of the decimateWorstRoots method for classification purposes.
During the decimation process, roughly half of the roots are kept based on their score for individual classes of the ClassificationLearningEnvironment. To do so, for each class of the ClassificationLearningEnvironment, the roots providing the best score are preserved during the decimation process, even if their global score over all classes is not among the best.
The remaining half of preserved roots is selected using the general score obtained over all classes.
This per-class preservation is activated only if there is a sufficient number of root vertices in the TPGGraph after decimation to guarantee that all classes are preserved equally. In other words, the same number of roots is marked for preservation for each class, which can only be achieved if the number of roots to preserve during the decimation process is greater than or equal to twice the number of actions of the ClassificationLearningEnvironment. If an insufficient number of roots is preserved during the decimation process, all preserved roots are selected based on their general score.
The results map is updated by the method to keep only the results of non-decimated roots.
Reimplemented from Learn::LearningAgent.
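A minimal sketch of the decimation rule described above, assuming each root is summarized by a hypothetical RootScores record holding one score per class and a global score. It keeps (nbRootsToKeep / 2) / nbClasses roots per class when nbRootsToKeep >= 2 * nbClasses, then fills the remaining slots by global score. Skipping roots already preserved for another class is an assumption of this sketch, and it returns the ids of the kept roots instead of updating a results multimap as GEGELATI does.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Hypothetical summary of one root's last evaluation (not a GEGELATI type).
struct RootScores {
    int id;                            // illustrative root identifier
    std::vector<double> scorePerClass; // one score per class (action)
    double globalScore;                // score over all classes
};

// Returns the ids (ascending) of the roots preserved by the decimation.
std::vector<int> selectRootsToKeep(const std::vector<RootScores>& roots,
                                   std::size_t nbRootsToKeep,
                                   std::size_t nbClasses) {
    std::set<int> kept;
    std::vector<const RootScores*> order;
    for (const RootScores& r : roots) {
        assert(r.scorePerClass.size() >= nbClasses);
        order.push_back(&r);
    }

    // Per-class preservation only if every class can get the same quota.
    if (nbClasses > 0 && nbRootsToKeep >= 2 * nbClasses) {
        const std::size_t perClass = (nbRootsToKeep / 2) / nbClasses;
        for (std::size_t c = 0; c < nbClasses; ++c) {
            // Sort so that the best score for class c comes first.
            std::sort(order.begin(), order.end(),
                      [c](const RootScores* a, const RootScores* b) {
                          return a->scorePerClass[c] > b->scorePerClass[c];
                      });
            std::size_t taken = 0;
            for (const RootScores* r : order) {
                if (taken == perClass) break;
                // Assumption: a root already kept for another class is skipped.
                if (kept.insert(r->id).second) ++taken;
            }
        }
    }

    // Remaining slots (or all slots, if per-class preservation is off)
    // go to the roots with the best global score.
    std::sort(order.begin(), order.end(),
              [](const RootScores* a, const RootScores* b) {
                  return a->globalScore > b->globalScore;
              });
    for (const RootScores* r : order) {
        if (kept.size() >= nbRootsToKeep) break;
        kept.insert(r->id);
    }
    return std::vector<int>(kept.begin(), kept.end());
}
```

For example, with two classes and four roots to keep, one root per class is preserved for its class score before the two best global scores fill the remaining slots. With only three roots to keep, 3 < 2 * 2, so the selection falls back to the global score alone.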
evaluateJob() [inline, override, virtual]
Specialization of the evaluateJob method for classification purposes.
This method returns a ClassificationEvaluationResult for the evaluated root instead of the usual EvaluationResult. For each class, the score recorded for the root corresponds to its F1 score for this class.
Reimplemented from Learn::LearningAgent.
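Assuming the standard definition, the F1 score of a class is 2*TP / (2*TP + FP + FN), computed from the true positives, false positives and false negatives of that class, where the action chosen by the root is read as the predicted class. The hypothetical helper below computes it for every class from lists of actual and predicted class indices; it illustrates the metric and is not GEGELATI's ClassificationEvaluationResult code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: per-class F1 from actual and predicted class indices
// in [0, nbClasses). F1 = 2*TP / (2*TP + FP + FN); a class with no TP, FP
// or FN at all gets a score of 0.
std::vector<double> f1ScorePerClass(const std::vector<std::size_t>& actual,
                                    const std::vector<std::size_t>& predicted,
                                    std::size_t nbClasses) {
    assert(actual.size() == predicted.size());
    std::vector<std::size_t> tp(nbClasses, 0), fp(nbClasses, 0), fn(nbClasses, 0);
    for (std::size_t i = 0; i < actual.size(); ++i) {
        assert(actual[i] < nbClasses && predicted[i] < nbClasses);
        if (actual[i] == predicted[i]) {
            ++tp[actual[i]];    // correct classification
        } else {
            ++fp[predicted[i]]; // wrongly predicted class: false positive
            ++fn[actual[i]];    // missed actual class: false negative
        }
    }
    std::vector<double> f1(nbClasses, 0.0);
    for (std::size_t c = 0; c < nbClasses; ++c) {
        const std::size_t denom = 2 * tp[c] + fp[c] + fn[c];
        if (denom != 0) f1[c] = 2.0 * tp[c] / denom;
    }
    return f1;
}
```

For actual classes {0, 0, 1, 1} and predicted classes {0, 1, 1, 1}, class 0 gets 2/3 (one hit, one miss) and class 1 gets 0.8 (two hits, one false positive).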