-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNetTIME_CRF_predict.py
117 lines (105 loc) · 2.66 KB
/
NetTIME_CRF_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
from NetTIME import CRFPredictWorkflow
######## User Input ########
def build_parser():
    """Build the command-line parser for NetTIME CRF prediction.

    Returns:
        argparse.ArgumentParser: parser defining the prediction, data and
        save options consumed by ``configure_workflow``.
    """
    # ``description`` must be passed by keyword: the first positional
    # parameter of ArgumentParser is ``prog``, so passing this sentence
    # positionally turned it into the program name shown in usage/--help.
    parser = argparse.ArgumentParser(
        description=(
            "Make predictions using NetTIME linear chain CRF classifier."
        )
    )
    # Prediction parameters
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2700,
        help="Prediction batch size. Default: 2700",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=10,
        help="Number of workers used to perform multi-process data loading. "
        "Default: 10",
    )
    parser.add_argument(
        "--seed", type=int, default=1111, help="Random seed. Default: 1111"
    )
    parser.add_argument(
        "--model_config",
        type=str,
        default=None,
        help="Specify an alternative path to CRF .config file.",
    )
    parser.add_argument(
        "--best_ckpt",
        type=str,
        default=None,
        help="Specify an alternative path to a best model checkpoint .ckpt file or "
        "a best checkpoint evaluation .json file, which will be used to make "
        "predictions.",
    )
    # Data
    parser.add_argument(
        "--prediction_dir",
        type=str,
        default=None,
        help="Path to NetTIME prediction directory.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="TEST",
        help="Dataset type. Default: TEST.",
    )
    parser.add_argument(
        "--class_weight",
        type=str,
        default=None,
        help="Path to a numpy .npy file specifying the class weight. "
        "Default: None, use class weight generated from training data.",
    )
    # Save
    parser.add_argument(
        "--output_dir",
        type=str,
        default="experiments/",
        help="Root directory for saving experiment results. Default: experiments/",
    )
    parser.add_argument(
        "--experiment_name",
        type=str,
        default="training_example",
        help="Experiment name. Default: training_example",
    )
    parser.add_argument(
        "--result_dir",
        type=str,
        default=None,
        help="Specify an alternative location to save prediction files.",
    )
    parser.add_argument(
        "--tmp_dir",
        type=str,
        default="/tmp",
        help="Temporary directory for saving merged prediction .h5 file. Default: "
        "/tmp",
    )
    return parser


######## Configure workflow ########
def configure_workflow(workflow, args):
    """Copy parsed command-line options onto a prediction workflow.

    Args:
        workflow: object that receives each option as an attribute;
            ``main`` passes a freshly constructed ``CRFPredictWorkflow``.
        args (argparse.Namespace): options returned by
            ``build_parser().parse_args()``.

    Returns:
        The same ``workflow`` object, so the call can be chained.
    """
    # Prediction parameters
    workflow.batch_size = args.batch_size
    workflow.num_workers = args.num_workers
    workflow.seed = args.seed
    workflow.model_config = args.model_config
    workflow.best_ckpt = args.best_ckpt
    # Data
    workflow.prediction_dir = args.prediction_dir
    workflow.dtype = args.dtype
    workflow.class_weight = args.class_weight
    # Save
    workflow.output_dir = args.output_dir
    workflow.experiment_name = args.experiment_name
    workflow.result_dir = args.result_dir
    workflow.tmp_dir = args.tmp_dir
    # Args: the full namespace is also attached as-is.
    workflow.args = args
    return workflow


######## Model Run ########
def main(argv=None):
    """Parse the command line, configure a CRFPredictWorkflow and run it.

    Args:
        argv (list[str] | None): arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, matching the original
            script behavior.
    """
    args = build_parser().parse_args(argv)
    workflow = configure_workflow(CRFPredictWorkflow(), args)
    workflow.run()


# Only run when executed as a script. Under the "spawn" multiprocessing start
# method (default on Windows and macOS), worker processes started for
# multi-process data loading (see --num_workers) re-import the main module;
# without this guard each of them would re-parse argv and start another run.
if __name__ == "__main__":
    main()