Skip to content

Commit

Permalink
added script to compute Empirical Fisher
Browse files Browse the repository at this point in the history
  • Loading branch information
varisd committed Aug 17, 2018
1 parent a2fa6bc commit 551f8ae
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions scripts/compute_fisher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
"""Compute the Empirical Fisher matrix using a list of gradients.
The gradient tensors can be spread over multiple npz files. The mean
is computed over the first dimension (supposed to be a batch).
"""

import argparse
import os
import re
import glob

import numpy as np

from neuralmonkey.logging import log as _log


def log(message: str, color: str = "blue") -> None:
_log(message, color)


def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--file_prefix", type=str,
help="prefix of the npz files containing the gradients")
parser.add_argument("--output_path", type=str,
help="Path to output the Empirical Fisher to.")
args = parser.parse_args()

output_dict = {}
n = 0
for file in glob.glob("{}.*npz".format(args.file_prefix)):
log("Processing {}".format(file))
tensors = np.load(file)

# first dimension must be equal for all tensors (batch)
shapes = [tensors[f].shape for f in tensors.files]
assert all([x[0] == shapes[0][0] for x in shapes])

for varname in tensors.files:
res = np.sum(np.square(tensors[varname]), 0)
if varname in output_dict:
output_dict[varname] += res
else:
output_dict[varname] = res
n += shapes[0][0]

for name in output_dict:
output_dict[name] /= n

np.savez(args.output_path, **output_dict)


if __name__ == "__main__":
main()

0 comments on commit 551f8ae

Please sign in to comment.