diff --git a/nn/attention.py b/nn/attention.py index 4d6e04c1..bfd60a2a 100644 --- a/nn/attention.py +++ b/nn/attention.py @@ -86,6 +86,21 @@ def __init__( self.qkv_dim_total = 2 * key_dim_total + value_dim_total self.qkv_dim_per_head = 2 * self.key_dim_per_head + self.value_dim_per_head self.qkv = nn.Linear(in_dim, self.qkv_dim_total, with_bias=with_bias) + # In Fairseq MultiheadAttention, they use: + # nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) (same for q_proj, v_proj), + # where xavier_uniform_ means: + # std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + # a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + # _no_grad_uniform_(tensor, -a, a) + # Our nn.init.VarianceScaling with mode="fan_avg", distribution="uniform": + # scale = scale * 2.0 / float(fan_in + fan_out) + # limit = math.sqrt(3.0 * scale) + # nn.random(distribution="uniform", minval=-limit, maxval=limit, ...) + # Our fan_out is 3 times larger than in Fairseq, because we concatenate q,k,v. + # Assuming fan_in = fan_out in Fairseq, it means a factor 2 in the denominator (equivalent to the gain=1/sqrt(2) in Fairseq). + # So our default (Glorot, which is VarianceScaling with mode="fan_avg", distribution="uniform", scale=1.0) + # is already the same as Fairseq. + # The bias init is different, but not sure how important this is. if proj_dim: self.proj = nn.Linear(value_dim_total, proj_dim, with_bias=with_bias) else: