diff --git a/test/test_neighgen.py b/test/test_neighgen.py index a55b32b..b64dd88 100644 --- a/test/test_neighgen.py +++ b/test/test_neighgen.py @@ -2,6 +2,7 @@ import joblib import unittest +import numpy as np import pandas as pd from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import train_test_split @@ -91,10 +92,21 @@ def test_genetic_generator_generate_balanced(self): z = self.enc.encode([x.values])[0] # remove the class feature from the input instance gen = GeneticGenerator(bbox=self.bbox, dataset=self.dataset, encoder=self.enc, ocr=0.1, ngen=20) - neighbour = gen.generate(z, 1000, self.dataset.descriptor, self.enc) + neighbour = gen.generate(z, 500, self.dataset.descriptor, self.enc) # Assert the lenght of the generated dataset is at least 1000 self.assertGreaterEqual(neighbour.shape[0], 100) + dec_neighbour = self.enc.decode(neighbour) + # check dec_neighbour for rows that contain a None value and filter them out: + # if a row has a None value, that row is removed + dec_neighbour = dec_neighbour[~pd.isnull(dec_neighbour).any(axis=1)] + pred_neighbour = self.bbox.predict(dec_neighbour) + classes, count = np.unique(pred_neighbour, return_counts=True) + print('classes', classes) + print('count', count) + self.assertTrue(len(classes) > 1, "The generated dataset should contain at least two classes") + + if __name__ == '__main__':