-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathevaluation.py
More file actions
executable file
·59 lines (40 loc) · 2.06 KB
/
evaluation.py
File metadata and controls
executable file
·59 lines (40 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# coding=utf-8
from __future__ import print_function
import sys
import traceback
from tqdm import tqdm
def decode(examples, model, args, verbose=False, **kwargs):
    """Run ``model.parse`` on every example and collect the decodable hypotheses.

    The model is put into eval mode for the duration of the call; if it was
    in training mode on entry, training mode is restored on exit, including
    when decoding raises.

    Args:
        examples: sized iterable of examples. Each must expose a
            ``leaves_nodes`` mapping with the keys ``'leaves'``, ``'nodes'``,
            ``'spans'`` and ``'orig_leaves'``. The first three are presumably
            torch tensors (they are used via ``unsqueeze``/``long``/``cuda``)
            -- TODO confirm against the dataset code.
        model: parser exposing ``training``, ``eval()``, ``train()``,
            ``parse(...)`` and ``transition_system.ast_to_surface_code``.
        args: options object; ``args.cuda`` and ``args.beam_size`` are read.
        verbose: when true, print the number of examples before decoding.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        A list with one entry per example, in input order. Each entry is the
        list of hypotheses whose tree could be converted to surface code;
        each kept hypothesis has its ``code`` attribute set.
    """
    # TODO: create decoder for each dataset
    if verbose:
        print('evaluating %d examples' % len(examples))

    was_training = model.training
    model.eval()

    decode_results = []
    try:
        for example in tqdm(examples, desc='Decoding', file=sys.stdout, total=len(examples)):
            features = example.leaves_nodes

            # Add a leading dimension of size 1 (presumably the batch axis).
            # Only leaves and nodes are cast to long; spans keep their dtype.
            leaves = features['leaves'].unsqueeze(dim=0).long()
            nodes = features['nodes'].unsqueeze(dim=0).long()
            spans = features['spans'].unsqueeze(dim=0)
            if args.cuda:
                leaves = leaves.cuda()
                nodes = nodes.cuda()
                spans = spans.cuda()

            hyps = model.parse(leaves, nodes, spans,
                               orig_leaves=features['orig_leaves'],
                               context=None, beam_size=args.beam_size)

            decoded_hyps = []
            for hyp in hyps:
                try:
                    hyp.code = model.transition_system.ast_to_surface_code(hyp.tree)
                except Exception:
                    # Best effort: report the failure and drop this
                    # hypothesis. ``Exception`` (not a bare ``except``) lets
                    # KeyboardInterrupt/SystemExit stop the run.
                    print(traceback.format_exc())
                else:
                    decoded_hyps.append(hyp)

            decode_results.append(decoded_hyps)
    finally:
        # Restore the caller's mode even if decoding failed part-way.
        if was_training:
            model.train()

    return decode_results
def evaluate(examples, parser, evaluator, args, verbose=False, return_decode_result=False, eval_top_pred_only=False):
    """Decode ``examples`` with ``parser`` and score them with ``evaluator``.

    Args:
        examples: examples handed to both ``decode`` and the evaluator.
        parser: model passed to ``decode``; its ``transition_system`` is
            forwarded to the evaluator.
        evaluator: object exposing ``evaluate_dataset``.
        args: options object forwarded to ``decode`` and the evaluator.
        verbose: forwarded to ``decode``.
        return_decode_result: when true, also return the decode results.
        eval_top_pred_only: forwarded to the evaluator as ``fast_mode``.

    Returns:
        The value returned by ``evaluator.evaluate_dataset``, or the tuple
        ``(that value, decode results)`` when ``return_decode_result`` is
        true.
    """
    hypotheses = decode(examples, parser, args, verbose=verbose)
    metrics = evaluator.evaluate_dataset(
        examples,
        hypotheses,
        fast_mode=eval_top_pred_only,
        args=args,
        transition_system=parser.transition_system,
    )

    if not return_decode_result:
        return metrics

    # TODO: remove this
    return metrics, hypotheses