main.py (forked from despoisj/DeepAudioClassification)
# -*- coding: utf-8 -*-
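
# Entry point for the music genre classifier: pre-processes audio into spectrogram
# slices, then trains and/or evaluates a CNN on those slices (TFLearn).
# Usage: python main.py <mode> [<mode> ...]   with <mode> in {slice, train, test}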
import random
import string
import os
import sys
import numpy as np
from model import createModel
from datasetTools import getDataset
from config import slicesPath, batchSize, filesPerGenre, nbEpoch
from config import validationRatio, testRatio, sliceSize
from songToData import createSlicesFromAudio
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("mode", help="Slice the audio, train the CNN and/or test it", nargs='+', choices=["train","test","slice"])
args = parser.parse_args()
print("--------------------------")
print("| ** Config ** ")
print("| Validation ratio: {}".format(validationRatio))
print("| Test ratio: {}".format(testRatio))
print("| Slices per genre: {}".format(filesPerGenre))
print("| Slice size: {}".format(sliceSize))
print("--------------------------")
if "slice" in args.mode:
createSlicesFromAudio()
sys.exit()
# List genres: each sub-directory of slicesPath is one class
genres = os.listdir(slicesPath)
genres = [filename for filename in genres if os.path.isdir(slicesPath+filename)]
nbClasses = len(genres)

# Create model
model = createModel(nbClasses, sliceSize)
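# createModel (model.py) is expected to build a TFLearn DNN sized for sliceSize
# inputs with nbClasses outputs (one per genre directory found above)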
if "train" in args.mode:
#Create or load new dataset
train_X, train_y, validation_X, validation_y = getDataset(filesPerGenre, genres, sliceSize, validationRatio, testRatio, mode="train")
#Define run id for graphs
run_id = "MusicGenres - "+str(batchSize)+" "+''.join(random.SystemRandom().choice(string.ascii_uppercase) for _ in range(10))
#Train the model
print("[+] Training the model...")
model.fit(train_X, train_y, n_epoch=nbEpoch, batch_size=batchSize, shuffle=True, validation_set=(validation_X, validation_y), snapshot_step=100, show_metric=True, run_id=run_id)
print(" Model trained! ✅")
#Save trained model
print("[+] Saving the weights...")
model.save('musicDNN.tflearn')
print("[+] Weights saved! ✅💾")
if "test" in args.mode:
#Create or load new dataset
test_X, test_y = getDataset(filesPerGenre, genres, sliceSize, validationRatio, testRatio, mode="test")
#Load weights
print("[+] Loading weights...")
model.load('musicDNN.tflearn')
print(" Weights loaded! ✅")
testAccuracy = model.evaluate(test_X, test_y)[0]
print("[+] Test accuracy: {} ".format(testAccuracy))