Skip to content

Commit 551f8ae

Browse files
committed
added script to compute Empirical Fisher
1 parent a2fa6bc commit 551f8ae

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

scripts/compute_fisher.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#!/usr/bin/env python3
2+
"""Compute the Empirical Fisher matrix using a list of gradients.
3+
4+
The gradient tensors can be spread over multiple npz files. The mean
5+
is computed over the first dimension (supposed to be a batch).
6+
7+
"""
8+
9+
import argparse
10+
import os
11+
import re
12+
import glob
13+
14+
import numpy as np
15+
16+
from neuralmonkey.logging import log as _log
17+
18+
19+
def log(message: str, color: str = "blue") -> None:
20+
_log(message, color)
21+
22+
23+
def main() -> None:
24+
parser = argparse.ArgumentParser(description=__doc__)
25+
parser.add_argument("--file_prefix", type=str,
26+
help="prefix of the npz files containing the gradients")
27+
parser.add_argument("--output_path", type=str,
28+
help="Path to output the Empirical Fisher to.")
29+
args = parser.parse_args()
30+
31+
output_dict = {}
32+
n = 0
33+
for file in glob.glob("{}.*npz".format(args.file_prefix)):
34+
log("Processing {}".format(file))
35+
tensors = np.load(file)
36+
37+
# first dimension must be equal for all tensors (batch)
38+
shapes = [tensors[f].shape for f in tensors.files]
39+
assert all([x[0] == shapes[0][0] for x in shapes])
40+
41+
for varname in tensors.files:
42+
res = np.sum(np.square(tensors[varname]), 0)
43+
if varname in output_dict:
44+
output_dict[varname] += res
45+
else:
46+
output_dict[varname] = res
47+
n += shapes[0][0]
48+
49+
for name in output_dict:
50+
output_dict[name] /= n
51+
52+
np.savez(args.output_path, **output_dict)
53+
54+
55+
if __name__ == "__main__":
56+
main()

0 commit comments

Comments
 (0)