#ifndef CLASSIFICATION_LEARNING_AGENT_H
#define CLASSIFICATION_LEARNING_AGENT_H

#include <algorithm>
#include <cmath>
#include <numeric>
#include <stdexcept>
#include <type_traits>

#include "data/hash.h"
#include "learn/classificationEvaluationResult.h"
#include "learn/classificationLearningEnvironment.h"
#include "learn/evaluationResult.h"
#include "learn/learningAgent.h"
#include "learn/parallelLearningAgent.h"

namespace Learn {
    /// LearningAgent specialized for LearningEnvironments representing a
    /// classification problem.
    template <class BaseLearningAgent = ParallelLearningAgent>
    class ClassificationLearningAgent : public BaseLearningAgent
    {
        // The BaseLearningAgent must derive from LearningAgent.
        static_assert(
            std::is_convertible<BaseLearningAgent*, LearningAgent*>::value);

      public:
        /// Constructor for the ClassificationLearningAgent.
        ClassificationLearningAgent(
            ClassificationLearningEnvironment& le,
            const Instructions::Set& iSet, const LearningParameters& p,
            const TPG::TPGFactory& factory = TPG::TPGFactory())
            : BaseLearningAgent(le, iSet, p, factory){};

        /// Specialization of the evaluateJob method for classification
        /// purposes.
        virtual std::shared_ptr<EvaluationResult> evaluateJob(
            TPG::TPGExecutionEngine& tee, const Job& job,
            uint64_t generationNumber, LearningMode mode,
            LearningEnvironment& le) const override;

        /// Specialization of the decimateWorstRoots method for classification
        /// purposes.
        void decimateWorstRoots(
            std::multimap<std::shared_ptr<EvaluationResult>,
                          const TPG::TPGVertex*>& results) override;
    };
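    // Usage sketch (illustrative only, not part of this header), assuming a
    // user-defined MyClassifEnv deriving from
    // ClassificationLearningEnvironment, a user-filled Instructions::Set, and
    // the init()/train() methods inherited from the BaseLearningAgent:
    //
    //   MyClassifEnv env;                  // hypothetical environment
    //   Instructions::Set instructionSet;  // filled with instructions
    //   LearningParameters params;         // configured by the user
    //   Learn::ClassificationLearningAgent<> agent(env, instructionSet,
    //                                              params);
    //   agent.init();
    //   volatile bool exitTraining = false;
    //   agent.train(exitTraining, true);   // assumed inherited signature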
    // Definition of the evaluateJob specialization.
    template <class BaseLearningAgent>
    inline std::shared_ptr<EvaluationResult> ClassificationLearningAgent<
        BaseLearningAgent>::evaluateJob(TPG::TPGExecutionEngine& tee,
                                        const Job& job,
                                        uint64_t generationNumber,
                                        LearningMode mode,
                                        LearningEnvironment& le) const
    {
        // Only the first root of the Job is evaluated.
        const TPG::TPGVertex* root = job.getRoot();

        // Skip the evaluation if this root was already evaluated often enough.
        std::shared_ptr<Learn::EvaluationResult> previousEval;
        if (mode == LearningMode::TRAINING &&
            this->isRootEvalSkipped(*root, previousEval)) {
            return previousEval;
        }

        // One score and one evaluation counter per class.
        std::vector<double> result(this->learningEnvironment.getNbActions(),
                                   0.0);
        std::vector<size_t> nbEvalPerClass(
            this->learningEnvironment.getNbActions(), 0);

        // Evaluate the policy for several iterations.
        for (auto i = 0; i < this->params.nbIterationsPerPolicyEvaluation;
             i++) {
            // Compute a hash to seed the LearningEnvironment.
            Data::Hash<uint64_t> hasher;
            uint64_t hash = hasher(generationNumber) ^ hasher(i);

            // Reset the LearningEnvironment.
            le.reset(hash, mode);

            // Play the policy until a terminal state or the action limit.
            uint64_t nbActions = 0;
            while (!le.isTerminal() &&
                   nbActions < this->params.maxNbActionsPerEval) {
                // Get the action selected by the TPG and execute it.
                uint64_t actionID =
                    ((const TPG::TPGAction*)tee.executeFromRoot(*root).back())
                        ->getActionID();
                le.doAction(actionID);
                nbActions++;
            }

            // Update the per-class scores from the classification table.
            const auto& classificationTable =
                ((ClassificationLearningEnvironment&)le)
                    .getClassificationTable();
            for (uint64_t classIdx = 0; classIdx < classificationTable.size();
                 classIdx++) {
                uint64_t truePositive =
                    classificationTable.at(classIdx).at(classIdx);
                uint64_t falseNegative =
                    std::accumulate(classificationTable.at(classIdx).begin(),
                                    classificationTable.at(classIdx).end(),
                                    (uint64_t)0) -
                    truePositive;
                uint64_t falsePositive = 0;
                std::for_each(
                    classificationTable.begin(), classificationTable.end(),
                    [&classIdx, &falsePositive](
                        const std::vector<uint64_t>& classifForClass) {
                        falsePositive += classifForClass.at(classIdx);
                    });
                falsePositive -= truePositive;

                double recall = (double)truePositive /
                                (double)(truePositive + falseNegative);
                double precision = (double)truePositive /
                                   (double)(truePositive + falsePositive);
                // The F1 score is 0 when there is no true positive.
                double fScore = (truePositive != 0)
                                    ? 2 * (precision * recall) /
                                          (precision + recall)
                                    : 0.0;
                result.at(classIdx) += fScore;

                // Count the samples of this class seen in this iteration.
                nbEvalPerClass.at(classIdx) += truePositive + falseNegative;
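                // Illustrative example (not from the original source): with a
                // 2-class classificationTable {{8, 2}, {3, 7}} and classIdx 0:
                //   truePositive = 8, falseNegative = 2, falsePositive = 3,
                //   recall = 8 / 10 = 0.8, precision = 8 / 11 ~ 0.727,
                //   fScore = 2 * (0.727 * 0.8) / (0.727 + 0.8) ~ 0.762.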
            }
        }

        // Average the per-class scores over the evaluation iterations.
        const LearningParameters& p = this->params;
        std::for_each(result.begin(), result.end(), [p](double& val) {
            val /= (double)p.nbIterationsPerPolicyEvaluation;
        });

        // Build the ClassificationEvaluationResult and combine it with the
        // previous evaluation of this root, if any.
        auto evaluationResult = std::shared_ptr<EvaluationResult>(
            new ClassificationEvaluationResult(result, nbEvalPerClass));
        if (previousEval != nullptr) {
            *evaluationResult += *previousEval;
        }
        return evaluationResult;
    }
    // Definition of the decimateWorstRoots specialization.
    template <class BaseLearningAgent>
    void ClassificationLearningAgent<BaseLearningAgent>::decimateWorstRoots(
        std::multimap<std::shared_ptr<EvaluationResult>,
                      const TPG::TPGVertex*>& results)
    {
        // This specialization only works with ClassificationEvaluationResult.
        if (dynamic_cast<const ClassificationEvaluationResult*>(
                results.begin()->first.get()) == nullptr) {
            throw std::runtime_error(
                "ClassificationLearningAgent can not decimate worst roots for "
                "results whose type is not ClassificationEvaluationResult.");
        }

        // Compute the number of roots to delete and to keep.
        uint64_t totalNbRoot = this->tpg->getNbRootVertices();
        uint64_t nbRootsToDelete =
            (uint64_t)floor(this->params.ratioDeletedRoots * totalNbRoot);
        uint64_t nbRootsToKeep = (totalNbRoot - nbRootsToDelete);

        // Half of the surviving roots are the best roots of each class; the
        // rest are the roots with the best general score.
        uint64_t nbRootsKeptPerClass =
            (nbRootsToKeep / this->learningEnvironment.getNbActions()) / 2;
        uint64_t nbRootsKeptGeneralScore =
            nbRootsToKeep -
            this->learningEnvironment.getNbActions() * nbRootsKeptPerClass;
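        // Illustrative example (not from the original source): with
        // totalNbRoot = 100, ratioDeletedRoots = 0.5 and 3 classes:
        //   nbRootsToKeep = 50, nbRootsKeptPerClass = (50 / 3) / 2 = 8,
        //   nbRootsKeptGeneralScore = 50 - 3 * 8 = 26.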
        // Build the list of roots to keep.
        std::vector<const TPG::TPGVertex*> rootsToKeep;

        // First, keep the best roots of each class.
        for (uint64_t classIdx = 0;
             classIdx < this->learningEnvironment.getNbActions(); classIdx++) {
            // Sort the roots according to their score on this class.
            std::multimap<double, const TPG::TPGVertex*> sortedRoot;
            std::for_each(
                results.begin(), results.end(),
                [&sortedRoot,
                 &classIdx](const std::pair<std::shared_ptr<EvaluationResult>,
                                            const TPG::TPGVertex*>& res) {
                    sortedRoot.emplace(
                        ((ClassificationEvaluationResult*)res.first.get())
                            ->getScorePerClass()
                            .at(classIdx),
                        res.second);
                });

            // Keep the nbRootsKeptPerClass best roots of this class.
            auto iterator = sortedRoot.rbegin();
            for (auto i = 0; i < nbRootsKeptPerClass; i++) {
                // Do not insert a root twice.
                if (std::find(rootsToKeep.begin(), rootsToKeep.end(),
                              iterator->second) == rootsToKeep.end()) {
                    rootsToKeep.push_back(iterator->second);
                }
                iterator++;
            }
        }

        // Then, complete the list with the roots having the best general
        // score.
        auto iterator = results.rbegin();
        while (rootsToKeep.size() < nbRootsToKeep &&
               iterator != results.rend()) {
            if (std::find(rootsToKeep.begin(), rootsToKeep.end(),
                          iterator->second) == rootsToKeep.end()) {
                rootsToKeep.push_back(iterator->second);
            }
            iterator++;
        }

        // Remove every other root vertex (except TPGActions) from the graph
        // and from the stored results.
        auto allRoots = this->tpg->getRootVertices();
        auto& tpgRef = this->tpg;
        auto& resultsPerRootRef = this->resultsPerRoot;
        std::for_each(
            allRoots.begin(), allRoots.end(),
            [&rootsToKeep, &tpgRef, &resultsPerRootRef,
             &results](const TPG::TPGVertex* vert) {
                // Do not remove TPGActions nor roots marked to be kept.
                if (dynamic_cast<const TPG::TPGAction*>(vert) == nullptr &&
                    std::find(rootsToKeep.begin(), rootsToKeep.end(), vert) ==
                        rootsToKeep.end()) {
                    tpgRef->removeVertex(*vert);

                    // Keep the resultsPerRoot map clean.
                    resultsPerRootRef.erase(vert);

                    // Also remove the root from the results multimap.
                    std::multimap<std::shared_ptr<EvaluationResult>,
                                  const TPG::TPGVertex*>::iterator iter =
                        results.begin();
                    while (iter != results.end()) {
                        if (iter->second == vert) {
                            iter = results.erase(iter);
                        }
                        else {
                            iter++;
                        }
                    }
                }
            });
    }
} // namespace Learn

#endif // CLASSIFICATION_LEARNING_AGENT_H