# Source code for olympus.models.inits.kaiming

import torch.nn

from olympus.models.inits.base import Initialization


class Kaiming(Initialization):
    """Base class for Kaiming (He) style weight initializations.

    Subclasses are expected to provide ``layer_init(weight)``, which
    initializes a single weight tensor in place.

    References
    ----------
    .. [1] Xavier Glorot, Yoshua Bengio,
        "Understanding the difficulty of training deep feedforward neural networks"

    .. [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
        "Delving Deep into Rectifiers: Surpassing Human-Level Performance
        on ImageNet Classification", Feb 2015
    """

    def __call__(self, model):
        """Init model using given function for Linear and Conv2d, and {0, 1} for BatchNorm.

        Parameters
        ----------
        model:
            Object exposing ``modules()``; every sub-module it yields is
            inspected and initialized in place.

        Returns
        -------
        The same ``model`` object that was passed in.
        """
        # TODO: detect_non_linearities and pass relu or leaky_relu according to architecture.
        for m in model.modules():
            if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
                # ``layer_init`` is declared as ``layer_init(self, weight)``;
                # passing a second positional argument raised a TypeError.
                self.layer_init(m.weight)

                # Layers built with ``bias=False`` have no bias tensor.
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0.0)

            elif isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)):
                # Non-affine batch norm has no weight/bias to set.
                if m.affine:
                    torch.nn.init.constant_(m.weight, 1.0)
                    torch.nn.init.constant_(m.bias, 0.0)

        return model
class KaimingUniform(Kaiming):
    """Kaiming initialization sampling weights with ``kaiming_uniform_``.

    See :class:`.Kaiming`

    Parameters
    ----------
    a:
        Forwarded as ``a`` to ``torch.nn.init.kaiming_uniform_``.
    mode:
        Forwarded as ``mode`` (``'fan_in'`` or ``'fan_out'``).
    non_linearity:
        Forwarded as ``nonlinearity`` (``'leaky_relu'`` or ``'relu'``).
    """

    def __init__(self, a, mode, non_linearity):
        self.a = a
        self.mode = mode
        self.non_linearity = non_linearity

    def layer_init(self, weight, non_linearity=None):
        """Initialize ``weight`` in place.

        Parameters
        ----------
        weight:
            Tensor to initialize.
        non_linearity:
            Optional override for the non-linearity; when omitted (or
            ``None``) the value given at construction is used.  Accepting
            it keeps the method callable with either one or two arguments.
        """
        if non_linearity is None:
            non_linearity = self.non_linearity

        torch.nn.init.kaiming_uniform_(
            weight,
            a=self.a,
            mode=self.mode,
            nonlinearity=non_linearity,
        )

    @staticmethod
    def get_space():
        """Return the search space of each hyper-parameter, as strings."""
        return {
            'a': 'uniform(0, 1)',
            'mode': 'choices([fan_in, fan_out])',
            'non_linearity': 'choices([leaky_relu, relu])',
        }

    @staticmethod
    def defaults():
        """Return the default value of each hyper-parameter."""
        return {
            'a': 0,
            'mode': 'fan_in',
            'non_linearity': 'leaky_relu',
        }
class KaimingNormal(Kaiming):
    """Kaiming initialization sampling weights with ``kaiming_normal_``.

    See :class:`.Kaiming`

    Derives from :class:`.Kaiming` (itself an ``Initialization``) so that
    the model-walking ``__call__`` is inherited, exactly like
    ``KaimingUniform``.

    Parameters
    ----------
    a:
        Forwarded as ``a`` to ``torch.nn.init.kaiming_normal_``.
    mode:
        Forwarded as ``mode`` (``'fan_in'`` or ``'fan_out'``).
    non_linearity:
        Forwarded as ``nonlinearity`` (``'leaky_relu'`` or ``'relu'``).
    """

    def __init__(self, a, mode, non_linearity):
        self.a = a
        self.mode = mode
        self.non_linearity = non_linearity

    def layer_init(self, weight, non_linearity=None):
        """Initialize ``weight`` in place.

        Parameters
        ----------
        weight:
            Tensor to initialize.
        non_linearity:
            Optional override for the non-linearity; when omitted (or
            ``None``) the value given at construction is used.  Accepting
            it keeps the method callable with either one or two arguments.
        """
        if non_linearity is None:
            non_linearity = self.non_linearity

        torch.nn.init.kaiming_normal_(
            weight,
            a=self.a,
            mode=self.mode,
            nonlinearity=non_linearity,
        )

    @staticmethod
    def get_space():
        """Return the search space of each hyper-parameter, as strings."""
        return {
            'a': 'uniform(0, 1)',
            'mode': 'choices([fan_in, fan_out])',
            'non_linearity': 'choices([leaky_relu, relu])',
        }

    @staticmethod
    def defaults():
        """Return the default value of each hyper-parameter."""
        return {
            'a': 0,
            'mode': 'fan_in',
            'non_linearity': 'leaky_relu',
        }
# Maps each initialization name to the class implementing it.
builders = {
    'kinit_uniform': KaimingUniform,
    'kinit_normal': KaimingNormal,
}