|
5 | 5 | import sys |
6 | 6 | import pprint |
7 | 7 | import collections |
8 | | -from os.path import abspath, join, exists |
9 | | - |
10 | | -def read_input(path): |
11 | | - with open(path, encoding='utf8') as fp: |
12 | | - return fp.read() |
13 | | - |
14 | | - |
15 | | -def read_phrases(path): |
16 | | - phrases = {} |
17 | | - |
18 | | - with open(path) as fp: |
19 | | - for line in fp: |
20 | | - idx, start, end = line.split() |
21 | | - phrases[int(idx)] = dict( |
22 | | - start=int(start), |
23 | | - end=int(end), |
24 | | - ) |
25 | | - |
26 | | - return phrases |
27 | | - |
28 | | - |
29 | | -def read_labels(path): |
30 | | - labels = {} |
31 | | - |
32 | | - with open(path) as fp: |
33 | | - for line in fp: |
34 | | - idx, label = line.split() |
35 | | - labels[int(idx)] = label |
36 | | - |
37 | | - return labels |
38 | | - |
39 | | - |
40 | | -def read_links(path): |
41 | | - links = [] |
42 | | - |
43 | | - with open(path) as fp: |
44 | | - for line in fp: |
45 | | - line = line.split() |
46 | | - rel = line[0] |
47 | | - |
48 | | - try: |
49 | | - links.append(dict( |
50 | | - rel=rel, |
51 | | - arg1=int(line[1]), |
52 | | - arg2=int(line[2]), |
53 | | - arg3=None, |
54 | | - )) |
55 | | - except: |
56 | | - pass |
57 | | - |
58 | | - return links |
59 | | - |
60 | | - |
61 | | -def find_obj(objs, x): |
62 | | - if isinstance(objs, dict): |
63 | | - source = objs.items() |
64 | | - else: |
65 | | - source = zip(objs, objs) |
66 | | - |
67 | | - for idx, l in source: |
68 | | - for field in l: |
69 | | - if l[field] != x[field]: |
70 | | - break |
71 | | - else: |
72 | | - return idx |
| 8 | +from tools import (get_span, |
| 9 | + read_input, |
| 10 | + read_phrases, |
| 11 | + read_links, |
| 12 | + read_labels, |
| 13 | + compare_phrases, |
| 14 | + compare_links, |
| 15 | + compare_labels) |
73 | 16 |
|
74 | | - return None |
75 | | - |
76 | | - |
77 | | -def between(x, a, b): |
78 | | - return x >= a and x <= b |
79 | | - |
80 | | - |
81 | | -def intersect(x1, y1, x2, y2): |
82 | | - if x1 >= y1: |
83 | | - return False |
84 | | - if x2 >= y2: |
85 | | - return False |
86 | | - |
87 | | - return between(x1, x2, y2) or between(y1, x2, y2) or between(x2, x1, y1) or between(y2, x1, y1) |
88 | | - |
89 | | - |
90 | | -def find_partial(objs, x): |
91 | | - fidx = find_obj(objs, x) |
92 | | - |
93 | | - if fidx: |
94 | | - return fidx, True |
95 | | - |
96 | | - start = x['start'] |
97 | | - end = x['end'] |
98 | | - |
99 | | - for idx, l in objs.items(): |
100 | | - sstart = l['start'] |
101 | | - send = l['end'] |
102 | | - |
103 | | - if intersect(start, end, sstart, send): |
104 | | - return idx, False |
105 | | - |
106 | | - return None, False |
107 | | - |
108 | | - |
109 | | -def sort(items): |
110 | | - return sorted(items, key=lambda x: (x['start'], x['end'])) |
111 | | - |
112 | | - |
113 | | -def compare_phrases(gold_phrases, dev_phrases): |
114 | | - correct = [] |
115 | | - missing = [] |
116 | | - spurious = [] |
117 | | - partial = [] |
118 | | - mapping = {} |
119 | | - |
120 | | - for idx, l in gold_phrases.items(): |
121 | | - fidx, exact = find_partial(dev_phrases, l) |
122 | | - |
123 | | - if fidx and not "eval:%i"%fidx in mapping: |
124 | | - if exact: |
125 | | - correct.append(l) |
126 | | - else: |
127 | | - partial.append((l, dev_phrases[fidx])) |
128 | | - |
129 | | - mapping["ref:%i"%idx] = fidx |
130 | | - mapping["eval:%i"%fidx] = idx |
131 | | - else: |
132 | | - missing.append(l) |
133 | | - |
134 | | - for fidx, l in dev_phrases.items(): |
135 | | - if not "eval:%i"%fidx in mapping: |
136 | | - spurious.append(l) |
137 | | - |
138 | | - return dict( |
139 | | - correct=sort(correct), |
140 | | - missing=sort(missing), |
141 | | - spurious=sort(spurious), |
142 | | - partial=partial, |
143 | | - mapping=mapping, |
144 | | - ) |
145 | | - |
146 | | - |
147 | | -def compare_labels(gold, dev, mapping): |
148 | | - confussion_matrix = collections.defaultdict(lambda: 0) |
149 | | - |
150 | | - correct = [] |
151 | | - incorrect = [] |
152 | | - spurious = [] |
153 | | - missing = [] |
154 | | - |
155 | | - for idx, l in gold.items(): |
156 | | - fidx = mapping.get('ref:%i' % idx) |
157 | | - |
158 | | - if not fidx: |
159 | | - missing.append(dict(id=idx, label=l)) |
160 | | - continue |
161 | | - |
162 | | - if not fidx in dev: |
163 | | - missing.append(dict(id=idx, label=l)) |
164 | | - confussion_matrix[(l, 'None')] += 1 |
165 | | - |
166 | | - l2 = dev[fidx] |
167 | | - confussion_matrix[(l, l2)] += 1 |
168 | | - |
169 | | - if l == l2: |
170 | | - correct.append(dict(fidx=fidx, label=l2)) |
171 | | - else: |
172 | | - incorrect.append(dict(fidx=fidx, label=l2, correct=l)) |
173 | | - |
174 | | - for fidx, l in dev.items(): |
175 | | - if "eval:%i"%fidx in mapping: |
176 | | - continue |
177 | | - |
178 | | - spurious.append(dict(fidx=fidx, label=l)) |
179 | | - |
180 | | - return dict( |
181 | | - confussion_matrix=confussion_matrix, |
182 | | - correct=correct, |
183 | | - incorrect=incorrect, |
184 | | - missing=missing, |
185 | | - spurious=spurious, |
186 | | - ) |
187 | | - |
188 | | - |
189 | | -def map_entities(x, mapping, map_key): |
190 | | - result = dict( |
191 | | - rel=x['rel'], |
192 | | - arg1 = None, |
193 | | - arg2 = None, |
194 | | - arg3 = None, |
195 | | - ) |
196 | | - |
197 | | - for key in ["arg1", "arg2", "arg3"]: |
198 | | - value = x[key] |
199 | | - |
200 | | - if value is None: |
201 | | - result[key] = None |
202 | | - continue |
203 | | - |
204 | | - mapped = map_key+":%i"%value |
205 | | - |
206 | | - if not mapped in mapping: |
207 | | - return False |
208 | | - |
209 | | - result[key] = mapping[mapped] |
210 | | - |
211 | | - return result |
212 | | - |
213 | | - |
214 | | -def find_relation(rel, relations): |
215 | | - for r in relations: |
216 | | - for k in ["rel", "arg1", "arg2", "arg3"]: |
217 | | - if r[k] != rel[k]: |
218 | | - break |
219 | | - else: |
220 | | - return True |
221 | | - |
222 | | - return False |
223 | | - |
224 | | - |
225 | | -def compare_links(gold_links, dev_links, mapping): |
226 | | - correct = [] |
227 | | - missing = [] |
228 | | - spurious = [] |
229 | | - |
230 | | - for rel in gold_links: |
231 | | - mapped = map_entities(rel, mapping, "ref") |
232 | | - |
233 | | - if not mapped: |
234 | | - missing.append(rel) |
235 | | - continue |
236 | | - |
237 | | - if not find_relation(mapped, dev_links): |
238 | | - missing.append(rel) |
239 | | - continue |
240 | | - |
241 | | - correct.append(rel) |
242 | | - |
243 | | - for rel in dev_links: |
244 | | - mapped = map_entities(rel, mapping, "eval") |
245 | | - |
246 | | - if not mapped: |
247 | | - spurious.append(rel) |
248 | | - continue |
249 | | - |
250 | | - if not find_relation(mapped, gold_links): |
251 | | - spurious.append(rel) |
252 | | - |
253 | | - return dict( |
254 | | - correct=correct, |
255 | | - missing=missing, |
256 | | - spurious=spurious, |
257 | | - ) |
258 | | - |
259 | | - |
260 | | -def get_span(sentences, obj): |
261 | | - return sentences[obj["start"]:obj["end"]] |
| 17 | +from os.path import abspath, join, exists |
262 | 18 |
|
263 | 19 |
|
264 | 20 | def evaluate_phrases(input_file, gold_phrases_file, dev_phrases_file): |
|
0 commit comments