Skip to content

Commit

Permalink
updating test unit for genetic generator
Browse files Browse the repository at this point in the history
  • Loading branch information
rinziv committed Dec 3, 2024
1 parent 742bc54 commit 8ae89b4
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion test/test_neighgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import joblib
import unittest

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -91,10 +92,21 @@ def test_genetic_generator_generate_balanced(self):
z = self.enc.encode([x.values])[0] # remove the class feature from the input instance

gen = GeneticGenerator(bbox=self.bbox, dataset=self.dataset, encoder=self.enc, ocr=0.1, ngen=20)
neighbour = gen.generate(z, 1000, self.dataset.descriptor, self.enc)
neighbour = gen.generate(z, 500, self.dataset.descriptor, self.enc)
# Assert the lenght of the generated dataset is at least 1000
self.assertGreaterEqual(neighbour.shape[0], 100)

dec_neighbour = self.enc.decode(neighbour)
# checking and filtering the rows in dec_neighbour that contains a None value
# if there is a None value, the row is removed
dec_neighbour = dec_neighbour[~pd.isnull(dec_neighbour).any(axis=1)]
pred_neighbour = self.bbox.predict(dec_neighbour)
classes, count = np.unique(pred_neighbour, return_counts=True)
print('classes', classes)
print('count', count)
self.assertTrue(len(classes) > 1, "The generated dataset should contain at least two classes")




if __name__ == '__main__':
Expand Down

0 comments on commit 8ae89b4

Please sign in to comment.