-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathmain.py
executable file
·98 lines (89 loc) · 4.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from models.Others.LR import model as LR
from models.Others.NB import model as NB
from models.TextCNN import model as TextCNN
from models.TextRNN import model as TextRNN
from models.CRNN import model as CRNN
from models.RCNN import model as RCNN
from models.HAN import model as HAN
from models.Ensembles.bagging import model as bagging
from models.Ensembles.stacking import model as stacking
from configs import general_config
def main(model_factory=None, steps=("evaluate", "predict")):
    """Build one model and call a sequence of its no-argument methods.

    Calling ``main()`` with no arguments keeps the script's original behaviour:
    construct the ``stacking`` ensemble, then call ``evaluate()`` and then
    ``predict()`` on it.

    Args:
        model_factory: Zero-argument callable that returns the model to run.
            ``None`` (the default) selects the module-level ``stacking``
            ensemble. Other examples: ``bagging`` (uses ``evaluate``/``predict``),
            or ``NB`` with ``steps=("train", "test")``.
        steps: Names of methods to call on the model, in order, each with no
            arguments. For the stacking ensemble, passing
            ``("train_1", "train_2", "evaluate", "predict")`` also runs its two
            training stages. A single string is treated as one step name.

    Raises:
        AttributeError: If the model has no method with one of the given names.
    """
    # Resolve the default at call time, not definition time. Otherwise the
    # signature would be bound to the imported name at import.
    if model_factory is None:
        model_factory = stacking
    # Guard against main(steps="evaluate") being iterated character by character.
    if isinstance(steps, str):
        steps = (steps,)

    model = model_factory()
    for step in steps:
        getattr(model, step)()

    # Recipes formerly kept in this function as commented-out code:
    #   TextCNN(model_type="nonstatic") / TextRNN() / CRNN() / RCNN() / HAN():
    #     fit(with_validation=False, num_epochs=N, num_visual=0), with
    #     N = 130 / 150 / 70 / 50 / 110 respectively, then
    #     evaluate(load_path="checkpoints/<Model>/train") and
    #     predict(load_path="checkpoints/<Model>/train").
    #   Validation runs used fit(with_validation=True) with checkpoints under
    #     "checkpoints/<Model>/train_valid", evaluated on general_config.train_file
    #     and general_config.valid_file with
    #     vocab2intPath=general_config.local_nonstatic_v2i_path.
    #   bagging(): fit(), evaluate(), predict().  NB()/LR(): train(), test().


if __name__ == "__main__":
    main()