From 8ae89b48da04e0d06041b3ee21f9d6afe661649c Mon Sep 17 00:00:00 2001 From: Salvo Rinzivillo Date: Tue, 3 Dec 2024 18:23:12 +0100 Subject: [PATCH] updating test unit for genetic generator --- test/test_neighgen.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test/test_neighgen.py b/test/test_neighgen.py index a55b32b..b64dd88 100644 --- a/test/test_neighgen.py +++ b/test/test_neighgen.py @@ -2,6 +2,7 @@ import joblib import unittest +import numpy as np import pandas as pd from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import train_test_split @@ -91,10 +92,21 @@ def test_genetic_generator_generate_balanced(self): z = self.enc.encode([x.values])[0] # remove the class feature from the input instance gen = GeneticGenerator(bbox=self.bbox, dataset=self.dataset, encoder=self.enc, ocr=0.1, ngen=20) - neighbour = gen.generate(z, 1000, self.dataset.descriptor, self.enc) + neighbour = gen.generate(z, 500, self.dataset.descriptor, self.enc) # Assert the lenght of the generated dataset is at least 1000 self.assertGreaterEqual(neighbour.shape[0], 100) + dec_neighbour = self.enc.decode(neighbour) + # checking and filtering the rows in dec_neighbour that contain a None value + # if there is a None value, the row is removed + dec_neighbour = dec_neighbour[~pd.isnull(dec_neighbour).any(axis=1)] + pred_neighbour = self.bbox.predict(dec_neighbour) + classes, count = np.unique(pred_neighbour, return_counts=True) + print('classes', classes) + print('count', count) + self.assertTrue(len(classes) > 1, "The generated dataset should contain at least two classes") + + if __name__ == '__main__':