-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathseparate.py
145 lines (123 loc) · 5.08 KB
/
separate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/env python
import yaml
import argparse
import torch as th
import numpy as np
from pathlib import Path
from nnet import supported_nnet
from executor.executor import Executor
from utils.audio_util import WaveReader, write_wav
from utils.mvdr_util import make_mvdr
class EgsReader:
    """
    Wrap every mixture of a wave scp into an example dict.

    Iterating yields ``(key, {"mix": samples})`` pairs, where ``key`` and
    ``samples`` are exactly what the underlying ``WaveReader`` yields.
    """

    def __init__(self, mix_scp, sr=16000):
        # mix_scp and the sample rate are passed straight to WaveReader
        self.mix_reader = WaveReader(mix_scp, sr=sr)

    def __len__(self):
        """Return the number of entries reported by the wrapped reader."""
        return len(self.mix_reader)

    def __iter__(self):
        for utt_id, samples in self.mix_reader:
            yield utt_id, {"mix": samples}
class Separator(object):
    """
    A simple wrapper for speech separation.

    Builds the network described by ``<cpt_dir>/train.yaml``, restores the
    checkpoint ``<cpt_dir>/best.pt.tar`` and runs the executor in eval mode,
    one example at a time.
    """

    def __init__(self, cpt_dir, get_mask=False, device_id=-1):
        """
        Args:
            cpt_dir: checkpoint directory containing ``train.yaml`` and
                ``best.pt.tar``.
            get_mask: passed to ``Executor`` as its ``get_mask`` argument.
            device_id: CUDA device index; a negative value runs on CPU.
        """
        cpt_dir = Path(cpt_dir)
        # must be set before _load_executor, which reads it
        self.get_mask = get_mask
        self.executor = self._load_executor(cpt_dir)
        cpt_ptr = cpt_dir / "best.pt.tar"
        # resume() returns the value reported below as the checkpoint epoch
        epoch = self.executor.resume(cpt_ptr.as_posix())
        print(f"Load checkpoint at {cpt_dir}, on epoch {epoch}")
        print(f"Nnet summary: {self.executor}")
        if device_id < 0:
            self.device = th.device("cpu")
        else:
            self.device = th.device(f"cuda:{device_id:d}")
        self.executor.to(self.device)
        self.executor.eval()

    def separate(self, egs, early_exit_threshold=0):
        """
        Do separation on a single example.

        Args:
            egs: dict whose ``"mix"`` entry is a numpy array. A leading batch
                axis is added before it is handed to the executor. The
                caller's dict is NOT modified, so the same ``egs`` can be
                separated again.
            early_exit_threshold: threshold for the executor's early-exit
                mechanism.

        Returns:
            list of numpy arrays, one per executor output, with all
            singleton axes squeezed away.
        """
        mix = th.from_numpy(egs["mix"][None, :]).to(self.device, non_blocking=True)
        # Shallow copy: the executor gets the tensor, the caller keeps numpy.
        inputs = dict(egs, mix=mix)
        with th.no_grad():
            spks = self.executor(inputs, early_exit_threshold=early_exit_threshold, record=True)
        return [s.detach().squeeze().cpu().numpy() for s in spks]

    def _load_executor(self, cpt_dir):
        """
        Construct the Executor described by ``cpt_dir / "train.yaml"``.

        Weights are not loaded here; ``__init__`` restores them afterwards
        via ``Executor.resume``.

        Raises:
            RuntimeError: if ``nnet_type`` is not a key of ``supported_nnet``.
        """
        with open(cpt_dir / "train.yaml", "r") as f:
            # NOTE(review): FullLoader resolves more tags than safe_load;
            # only load configs from trusted checkpoint directories.
            conf = yaml.load(f, Loader=yaml.FullLoader)
        nnet_type = conf["nnet_type"]
        if nnet_type not in supported_nnet:
            raise RuntimeError(f"Unknown network type: {nnet_type}")
        nnet = supported_nnet[nnet_type](**conf["nnet_conf"])
        executor = Executor(nnet, extractor_kwargs=conf["extractor_conf"], get_mask=self.get_mask)
        return executor
def run(args):
    """
    Separate every utterance listed in ``args.mix_scp`` and dump the results.

    For each utterance, the first ``args.num_spks`` separated signals are
    peak-normalised to 0.9 and written to ``<dump_dir>/<key>_<i>.wav``.
    When ``args.mvdr`` is set, the network outputs are first passed through
    ``make_mvdr`` together with the mixture. At the end, the exit layers and
    inference times recorded on the executor are printed with their means.
    """
    egs_reader = EgsReader(args.mix_scp, sr=args.sr)
    separator = Separator(args.checkpoint, device_id=args.device_id, get_mask=args.mvdr)
    dump_dir = Path(args.dump_dir)
    dump_dir.mkdir(exist_ok=True, parents=True)
    print("Start Separation " + ("w/ mvdr" if args.mvdr else "w/o mvdr"))
    for key, egs in egs_reader:
        print(f"Processing utterance {key}...")
        # keep the raw numpy mixture for MVDR before separate() is called
        mixed = egs["mix"]
        spks = separator.separate(egs, early_exit_threshold=args.early_exit_threshold)
        if args.mvdr:
            res1, res2 = make_mvdr(np.asfortranarray(mixed.T), spks)
            spks = [res1, res2]
        for i, s in enumerate(spks):
            if i >= args.num_spks:
                break
            peak = np.max(np.abs(s))
            # an all-zero output would otherwise be written as NaN (0 / 0)
            if peak > 0:
                s = s * 0.9 / peak
            write_wav(dump_dir / f"{key}_{i}.wav", s)
    exit_layers = separator.executor.nnet.transformer.inference_exit_layers
    inference_time = separator.executor.inference_time
    print(f"Exit layers: {exit_layers}")
    # guard the means: the lists are empty when no utterance was processed
    if len(exit_layers) > 0:
        print(f"Avg. exit layer: {sum(exit_layers) / len(exit_layers)}")
    print(f"Inference times: {inference_time}")
    if len(inference_time) > 0:
        print(f"Avg. inference time: {sum(inference_time) / len(inference_time)}")
    print(f"Processed {len(egs_reader)} utterances done")
def str2bool(value):
    """
    Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` turns any non-empty string, including
    "False", into True, so the usual spellings are parsed explicitly.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Command to do speech separation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # required: Separator cannot build a path from a missing checkpoint
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="Directory of checkpoint")
    parser.add_argument("--mix-scp",
                        type=str,
                        required=True,
                        help="Rspecifier for mixed audio")
    # underscore spelling kept for compatibility; dash alias matches the rest
    parser.add_argument("--num_spks",
                        "--num-spks",
                        type=int,
                        default=2,
                        help="Number of the speakers")
    parser.add_argument("--device-id",
                        type=int,
                        default=-1,
                        help="GPU-id to offload model to, -1 means running on CPU")
    parser.add_argument("--sr",
                        type=int,
                        default=16000,
                        help="Sample rate for mixture input")
    parser.add_argument("--dump-dir",
                        type=str,
                        default="sep",
                        help="Directory to dump separated speakers")
    # accepts a bare "--mvdr" as well as "--mvdr true/false"
    parser.add_argument("--mvdr",
                        type=str2bool,
                        nargs="?",
                        const=True,
                        default=False,
                        help="apply mvdr")
    parser.add_argument("--early_exit_threshold",
                        "--early-exit-threshold",
                        type=float,
                        default=0,
                        help="Threshold for the early exit mechanism")
    args = parser.parse_args()
    run(args)