-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathalign.py
61 lines (48 loc) · 2.06 KB
/
align.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import tempfile
import subprocess
from typing import List
from utils import remove_nonalphanumeric
def align_mwer_segmenter(ref: List[str], hyp: List[str], mwer_segmenter_path: str) -> List[str]:
    """
    Re-segments hyp according to the ref segmentation with mwerSegmenter.

    Both sides are lower-cased and passed through ``remove_nonalphanumeric``
    before being handed to the tool. Afterwards the original hypothesis tokens
    (with their punctuation and capitalization) are put back into the
    re-segmented lines.

    The tool is run inside a private temporary directory so that its
    ``__segments`` output file neither pollutes the caller's working directory
    nor collides with a concurrent run.

    :param ref: list of reference translations
    :param hyp: list of hypothesis translations
    :param mwer_segmenter_path: path to the mwerSegmenter executable
    :return: one string per line of the ``__segments`` output, built from the
        original hypothesis tokens
    :raises RuntimeError: if mwerSegmenter cannot be started, exits with a
        non-zero return code, or its output cannot be mapped back onto the
        hypothesis words
    """
    # Resolve before changing the working directory: a relative executable
    # path would otherwise be looked up relative to the temporary directory.
    executable = _resolve_executable(mwer_segmenter_path)

    # The context manager removes the directory even when an error is raised.
    with tempfile.TemporaryDirectory() as work_dir:
        ref_path = os.path.join(work_dir, "ref.txt")
        hyp_path = os.path.join(work_dir, "hyp.txt")
        _write_clean_lines(ref_path, ref)
        _write_clean_lines(hyp_path, hyp)

        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed through verbatim.
        command = [executable, "-hypfile", hyp_path, "-mref", ref_path]
        try:
            completed = subprocess.run(command, cwd=work_dir)
        except OSError as err:
            raise RuntimeError(f"Failed to align hyp\ncould not run mwerSegmenter: {err}") from err

        ret_code = completed.returncode
        if ret_code != 0:
            raise RuntimeError(f"Failed to align hyp\nmwerSegmenter return code: {ret_code}")

        # The tool writes its result to "__segments" in its working directory.
        with open(os.path.join(work_dir, "__segments"), "r", encoding="utf-8") as f:
            aligned_hyp = f.readlines()

    return _restore_original_tokens(aligned_hyp, hyp)


def _resolve_executable(path: str) -> str:
    """Return an absolute path for *path*, expanding ``~``/env vars and searching PATH."""
    import shutil

    # The previous shell-based invocation expanded these implicitly.
    expanded = os.path.expandvars(os.path.expanduser(path))
    located = shutil.which(expanded)
    # If nothing was found, keep the given path; running it then fails and is
    # reported as a RuntimeError by the caller.
    return os.path.abspath(located if located is not None else expanded)


def _write_clean_lines(path: str, lines: List[str]) -> None:
    """Write lower-cased, cleaned, whitespace-normalized *lines* to *path*, one per line."""
    cleaned = [" ".join(remove_nonalphanumeric(line.lower()).split()) + "\n" for line in lines]
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(cleaned)


def _vanishes_when_cleaned(token: str) -> bool:
    """Return True if *token* leaves nothing behind after the pre-alignment cleaning."""
    return not remove_nonalphanumeric(token.lower()).split()


def _restore_original_tokens(aligned_lines: List[str], hyp: List[str]) -> List[str]:
    """Map the cleaned, re-segmented words back onto the original hypothesis tokens."""
    tokens = " ".join(hyp).split()
    cursor = 0
    segments = []
    for line in aligned_lines:
        segment = []
        for word in line.split():
            # Tokens that disappeared during cleaning (e.g. stand-alone
            # punctuation) have no counterpart in the aligned output; copy
            # every consecutive one of them before matching the next word.
            while cursor < len(tokens) and _vanishes_when_cleaned(tokens[cursor]):
                segment.append(tokens[cursor])
                cursor += 1
            if cursor >= len(tokens) or (
                remove_nonalphanumeric(word.lower()) != remove_nonalphanumeric(tokens[cursor].lower())
            ):
                raise RuntimeError("Word mismatch after alignment")
            segment.append(tokens[cursor])
            cursor += 1
        segments.append(segment)

    # Tokens without a cleaned form that follow the last matched word would
    # otherwise be dropped; keep them on the final segment.
    if segments:
        while cursor < len(tokens) and _vanishes_when_cleaned(tokens[cursor]):
            segments[-1].append(tokens[cursor])
            cursor += 1

    return [" ".join(segment) for segment in segments]