Source code for olympus.dashboard.plots.layer

import math
import torch.nn as nn
import matplotlib.pyplot as plt

# Module-level side effect: importing this module switches matplotlib to its
# 'dark_background' style.
plt.style.use('dark_background')


def vizualize_param(p, ax=None):
    """Draw one parameter tensor onto a matplotlib axes.

    The tensor is squeezed first and then drawn according to its remaining
    number of dimensions:

    * 0 or 1 dimension: bar chart, one bar per element.
    * 2 dimensions: heatmap with a colorbar.
    * 3 or 4 dimensions: every trailing 2-D plane is drawn as a heatmap.
      All planes go onto the *same* axes, so later planes cover earlier
      ones; pass individual planes if each one needs its own axes.
    * more than 4 dimensions: nothing is drawn.

    Args:
        p: torch tensor (or ``nn.Parameter``) to display.
        ax: axes to draw on. When ``None`` a new figure and axes are created.

    Returns:
        The figure created by this call, or ``None`` when ``ax`` was supplied
        by the caller.
    """
    fig = None
    if ax is None:
        fig, ax = plt.subplots()

    def to_numpy(data):
        # .cpu() leaves host tensors untouched and makes tensors living on
        # another device convertible; .numpy() alone raises for those.
        return data.detach().cpu().numpy()

    def heatmap(data):
        im = ax.imshow(to_numpy(data))
        cbar = ax.figure.colorbar(im, ax=ax)
        cbar.ax.set_ylabel('', rotation=-90, va="bottom")

    def histogram(data):
        values = to_numpy(data)
        ax.bar(x=range(len(values)), height=values)

    p = p.squeeze()
    ndim = len(p.shape)

    if ndim == 0:
        # A one-element parameter squeezes down to a scalar; show it as a
        # single bar instead of silently drawing nothing.
        histogram(p.reshape(1))
    elif ndim == 1:
        histogram(p)
    elif ndim == 2:
        heatmap(p)
    elif ndim == 3:
        for i in range(p.shape[0]):
            heatmap(p[i])
    elif ndim == 4:
        for i in range(p.shape[0]):
            for j in range(p.shape[1]):
                heatmap(p[i, j])

    return fig
def _param_slices(p):
    """Return the tensors of one parameter that each get their own subplot."""
    shape = p.shape
    rank = len(shape)

    if rank == 3:
        return [p[i] for i in range(shape[0])]

    if rank == 4:
        return [p[i, j] for i in range(shape[0]) for j in range(shape[1])]

    if rank > 4:
        print('skip high dimension weight ', shape)
        return []

    # Scalars, vectors and matrices are plotted as a whole.
    return [p]


def vizualize_weights(module: nn.Module):
    """Plot every parameter of ``module`` on a grid of subplots.

    Each parameter is split into plottable pieces: a 3-D parameter
    contributes one piece per index of its first dimension, a 4-D parameter
    one piece per pair of indices of its first two dimensions, and anything
    with fewer dimensions is plotted as a whole. Parameters with more than
    four dimensions are skipped with a printed notice. Every piece is drawn
    in its own grid cell through ``vizualize_param``.

    Fewer than four pieces are stacked in a single column; otherwise a square
    grid with ``int(sqrt(n)) + 1`` cells per side is used.

    Args:
        module: module whose ``parameters()`` are displayed.

    Returns:
        The matplotlib figure holding the subplots. The figure is empty when
        the module has nothing to plot.
    """
    slices = [piece for p in module.parameters() for piece in _param_slices(p)]
    n = len(slices)

    fig = plt.figure()
    if n == 0:
        # A grid needs at least one row; return the empty figure instead of
        # letting add_gridspec raise.
        return fig

    if n < 4:
        row, col = n, 1
    else:
        side = int(math.sqrt(n)) + 1
        row = col = side

    gs = fig.add_gridspec(row, col)

    for k, piece in enumerate(slices):
        # Row-major placement; divmod also covers the single-column layout
        # (col == 1 gives r == k, c == 0).
        r, c = divmod(k, col)
        ax = fig.add_subplot(gs[r, c])
        vizualize_param(piece, ax)

    return fig
if __name__ == '__main__':
    # Manual smoke test: render the parameters of a small convolution layer
    # and open the interactive window.
    demo_layer = nn.Conv2d(3, 10, 3)
    fig = vizualize_weights(demo_layer)
    plt.show()