GEGELATI
classificationLearningAgent.h
1
37#ifndef CLASSIFICATION_LEARNING_AGENT_H
38#define CLASSIFICATION_LEARNING_AGENT_H
39
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <map>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <type_traits>
#include <typeinfo>
#include <unordered_set>
#include <vector>

#include "learn/classificationEvaluationResult.h"
#include "learn/classificationLearningEnvironment.h"
#include "learn/evaluationResult.h"
#include "learn/learningAgent.h"
#include "learn/parallelLearningAgent.h"
#include <data/hash.h>
51
52namespace Learn {
73 template <class BaseLearningAgent = ParallelLearningAgent>
74 class ClassificationLearningAgent : public BaseLearningAgent
75 {
76 static_assert(
77 std::is_convertible<BaseLearningAgent*, LearningAgent*>::value);
78
79 public:
92 const Instructions::Set& iSet, const LearningParameters& p,
93 const TPG::TPGFactory& factory = TPG::TPGFactory())
94 : BaseLearningAgent(le, iSet, p, factory){};
95
104 virtual std::shared_ptr<EvaluationResult> evaluateJob(
105 TPG::TPGExecutionEngine& tee, const Job& root,
106 uint64_t generationNumber, LearningMode mode,
107 LearningEnvironment& le) const override;
108
137 std::multimap<std::shared_ptr<EvaluationResult>,
138 const TPG::TPGVertex*>& results) override;
139 };
140
141 template <class BaseLearningAgent>
142 inline std::shared_ptr<EvaluationResult> ClassificationLearningAgent<
143 BaseLearningAgent>::evaluateJob(TPG::TPGExecutionEngine& tee,
144 const Job& job,
145 uint64_t generationNumber,
146 LearningMode mode,
147 LearningEnvironment& le) const
148 {
149 // Only consider the first root of jobs as we are not in adversarial
150 // mode
151 const TPG::TPGVertex* root = job.getRoot();
152
153 // Skip the root evaluation process if enough evaluations were already
154 // performed. In the evaluation mode only.
155 std::shared_ptr<Learn::EvaluationResult> previousEval;
156 if (mode == LearningMode::TRAINING &&
157 this->isRootEvalSkipped(*root, previousEval)) {
158 return previousEval;
159 }
160
161 // Init results
162 std::vector<double> result(this->learningEnvironment.getNbActions(),
163 0.0);
164 std::vector<size_t> nbEvalPerClass(
165 this->learningEnvironment.getNbActions(), 0);
166
167 // Evaluate nbIteration times
168 for (auto i = 0; i < this->params.nbIterationsPerPolicyEvaluation;
169 i++) {
170 // Compute a Hash
171 Data::Hash<uint64_t> hasher;
172 uint64_t hash = hasher(generationNumber) ^ hasher(i);
173
174 // Reset the learning Environment
175 le.reset(hash, mode);
176
177 uint64_t nbActions = 0;
178 while (!le.isTerminal() &&
179 nbActions < this->params.maxNbActionsPerEval) {
180 // Get the action
181 uint64_t actionID =
182 ((const TPG::TPGAction*)tee.executeFromRoot(*root).back())
183 ->getActionID();
184 // Do it
185 le.doAction(actionID);
186 // Count actions
187 nbActions++;
188 }
189
190 // Update results
191 const auto& classificationTable =
193 .getClassificationTable();
194 // for each class
195 for (uint64_t classIdx = 0; classIdx < classificationTable.size();
196 classIdx++) {
197 uint64_t truePositive =
198 classificationTable.at(classIdx).at(classIdx);
199 uint64_t falseNegative =
200 std::accumulate(classificationTable.at(classIdx).begin(),
201 classificationTable.at(classIdx).end(),
202 (uint64_t)0) -
203 truePositive;
204 uint64_t falsePositive = 0;
205 std::for_each(
206 classificationTable.begin(), classificationTable.end(),
207 [&classIdx, &falsePositive](
208 const std::vector<uint64_t>& classifForClass) {
209 falsePositive += classifForClass.at(classIdx);
210 });
211 falsePositive -= truePositive;
212
213 double recall = (double)truePositive /
214 (double)(truePositive + falseNegative);
215 double precision = (double)truePositive /
216 (double)(truePositive + falsePositive);
217 // If true positive is 0, set score to 0.
218 double fScore = (truePositive != 0) ? 2 * (precision * recall) /
219 (precision + recall)
220 : 0.0;
221 result.at(classIdx) += fScore;
222
223 nbEvalPerClass.at(classIdx) += truePositive + falseNegative;
224 }
225 }
226
227 // Before returning the EvaluationResult, divide the result per class by
228 // the number of iteration
229 const LearningParameters& p = this->params;
230 std::for_each(result.begin(), result.end(), [p](double& val) {
231 val /= (double)p.nbIterationsPerPolicyEvaluation;
232 });
233
234 // Create the EvaluationResult
235 auto evaluationResult = std::shared_ptr<EvaluationResult>(
236 new ClassificationEvaluationResult(result, nbEvalPerClass));
237
238 // Combine it with previous one if any
239 if (previousEval != nullptr) {
240 *evaluationResult += *previousEval;
241 }
242 return evaluationResult;
243 }
244
245 template <class BaseLearningAgent>
247 std::multimap<std::shared_ptr<EvaluationResult>, const TPG::TPGVertex*>&
248 results)
249 {
250 // Check that results are ClassificationEvaluationResults.
251 // (also throws on empty results)
252 const EvaluationResult* result = results.begin()->first.get();
253 if (typeid(ClassificationEvaluationResult) != typeid(*result)) {
254 throw std::runtime_error(
255 "ClassificationLearningAgent can not decimate worst roots for "
256 "results whose type is not ClassificationEvaluationResult.");
257 }
258
259 // Compute the number of root to keep/delete base on each criterion
260 uint64_t totalNbRoot = this->tpg->getNbRootVertices();
261 uint64_t nbRootsToDelete =
262 (uint64_t)floor(this->params.ratioDeletedRoots * totalNbRoot);
263 uint64_t nbRootsToKeep = (totalNbRoot - nbRootsToDelete);
264
265 // Keep ~half+ of the roots based on their general score on
266 // all class.
267 // and ~half- of the roots on a per class score (none if nbRoots to keep
268 // < 2*nb class)
269 uint64_t nbRootsKeptPerClass =
270 (nbRootsToKeep / this->learningEnvironment.getNbActions()) / 2;
271 uint64_t nbRootsKeptGeneralScore =
272 nbRootsToKeep -
273 this->learningEnvironment.getNbActions() * nbRootsKeptPerClass;
274
275 // Build a list of roots to keep
276 std::vector<const TPG::TPGVertex*> rootsToKeep;
277
278 // Insert roots to keep per class
279 for (uint64_t classIdx = 0;
280 classIdx < this->learningEnvironment.getNbActions(); classIdx++) {
281 // Fill a map with the roots and the score of the specific class as
282 // ID.
283 std::multimap<double, const TPG::TPGVertex*> sortedRoot;
284 std::for_each(
285 results.begin(), results.end(),
286 [&sortedRoot,
287 &classIdx](const std::pair<std::shared_ptr<EvaluationResult>,
288 const TPG::TPGVertex*>& res) {
289 sortedRoot.emplace(
290 ((ClassificationEvaluationResult*)res.first.get())
291 ->getScorePerClass()
292 .at(classIdx),
293 res.second);
294 });
295
296 // Keep the best nbRootsKeptPerClass (or less for reasons explained
297 // in the loop)
298 auto iterator = sortedRoot.rbegin();
299 for (auto i = 0; i < nbRootsKeptPerClass; i++) {
300 // If the root is not already marked to be kept
301 if (std::find(rootsToKeep.begin(), rootsToKeep.end(),
302 iterator->second) == rootsToKeep.end()) {
303 rootsToKeep.push_back(iterator->second);
304 }
305 // Advance the iterator no matter what.
306 // This means that if a root scores well for several classes
307 // it is kept only once anyway, but additional roots will not
308 // be kept for any of the concerned class.
309 iterator++;
310 }
311 }
312
313 // Insert remaining roots to keep
314 auto iterator = results.rbegin();
315 while (rootsToKeep.size() < nbRootsToKeep &&
316 iterator != results.rend()) {
317 // If the root is not already marked to be kept
318 if (std::find(rootsToKeep.begin(), rootsToKeep.end(),
319 iterator->second) == rootsToKeep.end()) {
320 rootsToKeep.push_back(iterator->second);
321 }
322 // Advance the iterator no matter what.
323 iterator++;
324 }
325
326 // Do the removal.
327 // Because of potential root actions, the preserved number of roots
328 // may be higher than the given ratio.
329 auto allRoots = this->tpg->getRootVertices();
330 auto& tpgRef = this->tpg;
331 auto& resultsPerRootRef = this->resultsPerRoot;
332 std::for_each(
333 allRoots.begin(), allRoots.end(),
334 [&rootsToKeep, &tpgRef, &resultsPerRootRef,
335 &results](const TPG::TPGVertex* vert) {
336 // Do not remove actions
337 if (dynamic_cast<const TPG::TPGAction*>(vert) == nullptr &&
338 std::find(rootsToKeep.begin(), rootsToKeep.end(), vert) ==
339 rootsToKeep.end()) {
340 tpgRef->removeVertex(*vert);
341
342 // Keep only results of non-decimated roots.
343 resultsPerRootRef.erase(vert);
344
345 // Update results also
346 std::multimap<std::shared_ptr<EvaluationResult>,
347 const TPG::TPGVertex*>::iterator iter =
348 results.begin();
349 while (iter != results.end()) {
350 if (iter->second == vert) {
351 results.erase(iter);
352 break;
353 }
354 iter++;
355 }
356 }
357 });
358 }
359}; // namespace Learn
360
361#endif
Class for storing a set of Instructions.
Definition: set.h:53
Specialization of the EvaluationResult class for classification LearningEnvironment.
Definition: classificationEvaluationResult.h:53
LearningAgent specialized for LearningEnvironments representing a classification problem.
Definition: classificationLearningAgent.h:75
void decimateWorstRoots(std::multimap< std::shared_ptr< EvaluationResult >, const TPG::TPGVertex * > &results) override
Specialization of the decimateWorstRoots method for classification purposes.
Definition: classificationLearningAgent.h:246
virtual std::shared_ptr< EvaluationResult > evaluateJob(TPG::TPGExecutionEngine &tee, const Job &root, uint64_t generationNumber, LearningMode mode, LearningEnvironment &le) const override
Specialization of the evaluateJob method for classification purposes.
Definition: classificationLearningAgent.h:143
ClassificationLearningAgent(ClassificationLearningEnvironment &le, const Instructions::Set &iSet, const LearningParameters &p, const TPG::TPGFactory &factory=TPG::TPGFactory())
Constructor for LearningAgent.
Definition: classificationLearningAgent.h:90
Specialization of the LearningEnvironment class for classification purposes.
Definition: classificationLearningEnvironment.h:49
Base class for storing all result of a policy evaluation within a LearningEnvironment.
Definition: evaluationResult.h:52
This class embeds roots for the simulations.
Definition: job.h:53
virtual const TPG::TPGVertex * getRoot() const
Getter of the root.
Definition: job.cpp:49
Interface for creating a Learning Environment.
Definition: learningEnvironment.h:80
virtual void reset(size_t seed=0, LearningMode mode=LearningMode::TRAINING)=0
Reset the LearningEnvironment.
virtual bool isTerminal() const =0
Method for checking if the LearningEnvironment has reached a terminal state.
virtual void doAction(uint64_t actionID)
Execute an action on the LearningEnvironment.
Definition: learningEnvironment.cpp:50
Class representing an Action of a TPGGraph.
Definition: tpgAction.h:52
Definition: tpgExecutionEngine.h:56
virtual const std::vector< const TPGVertex * > executeFromRoot(const TPGVertex &root)
Execute the TPGGraph starting from the given TPGVertex.
Definition: tpgExecutionEngine.cpp:120
Factory for creating all elements constituting a TPG.
Definition: tpgFactory.h:34
Abstract class representing the vertices of a TPGGraph.
Definition: tpgVertex.h:49
Definition: adversarialEvaluationResult.h:45
LearningMode
Different modes in which the LearningEnvironment can be reset.
Definition: learningEnvironment.h:58
Structure for simplifying the transmission of LearningParameters to functions.
Definition: learningParameters.h:53