RF Trafo enc/dec, rename layer arg, support build dict, layer_opts
albertz committed Dec 9, 2024
1 parent a98f840 commit 55c6a02
Showing 2 changed files with 55 additions and 24 deletions.
43 changes: 29 additions & 14 deletions returnn/frontend/decoder/transformer.py
@@ -45,14 +45,15 @@ def __init__(
         num_heads: int = 8,
         att_dropout: float = 0.1,
         norm: Union[type, Dict[str, Any], rf.Module, Callable] = rf.LayerNorm,
-        decoder_layer: Optional[Union[TransformerDecoderLayer, rf.Module, type, Any]] = None,
-        decoder_layer_opts: Optional[Dict[str, Any]] = None,
+        layer: Optional[Union[TransformerDecoderLayer, rf.Module, type, Dict[str, Any], Any]] = None,
+        layer_opts: Optional[Dict[str, Any]] = None,
         embed_dim: Optional[Dim] = None,
         share_embedding: bool = None,
         input_embedding_scale: float = None,
         input_dropout: float = None,
         logits_with_bias: bool = False,
         sequential=rf.Sequential,
+        **compat_kwargs,
     ):
         """
         :param encoder_dim: for cross-attention. None if no cross-attention.
@@ -67,8 +68,8 @@
         :param num_heads: the number of attention heads
         :param att_dropout: attention dropout value
         :param norm: pre-normalization for FF and attention blocks
-        :param decoder_layer: an instance of :class:`TransformerDecoderLayer` or similar
-        :param decoder_layer_opts: options for the encoder layer
+        :param layer: an instance of :class:`TransformerDecoderLayer` or similar
+        :param layer_opts: options for the decoder layer
         :param embed_dim: if given, will first have an embedding [vocab,embed] and then a linear [embed,model].
         :param share_embedding:
         :param input_embedding_scale:
@@ -78,6 +79,16 @@
         """
         super().__init__()

+        if compat_kwargs:
+            if "decoder_layer" in compat_kwargs:  # compatibility, we used to have this before
+                assert layer is None
+                layer = compat_kwargs.pop("decoder_layer")
+            if "decoder_layer_opts" in compat_kwargs:  # compatibility, we used to have this before
+                assert layer_opts is None
+                layer_opts = compat_kwargs.pop("decoder_layer_opts")
+            if compat_kwargs:
+                raise TypeError(f"unexpected kwargs {compat_kwargs!r}")
+
         if not isinstance(vocab_dim, Dim):
             raise TypeError(f"TransformerDecoder: unexpected vocab_dim {vocab_dim!r} type {type(vocab_dim)}")
         if isinstance(model_dim, int):
@@ -136,8 +147,8 @@ def __init__(
             input_dropout = dropout if BehaviorVersion.get() >= 20 else 0.0
         self.input_dropout = input_dropout

-        if not decoder_layer or isinstance(decoder_layer, type):
-            decoder_layer_opts_ = dict(
+        if not layer or isinstance(layer, (dict, type)):
+            layer_opts_ = dict(
                 encoder_dim=encoder_dim,
                 out_dim=model_dim,
                 ff=ff,
@@ -148,16 +159,20 @@
                 att_dropout=att_dropout,
                 norm=norm,
             )
-            if decoder_layer_opts:
-                decoder_layer_opts_.update(decoder_layer_opts)
-            if not decoder_layer:
-                decoder_layer = TransformerDecoderLayer(**decoder_layer_opts_)
-            elif isinstance(decoder_layer, type):
-                decoder_layer = decoder_layer(**decoder_layer_opts_)
+            layer_opts_ = {k: v for (k, v) in layer_opts_.items() if v is not NotSpecified}
+            if layer_opts:
+                layer_opts_.update(layer_opts)
+            if not layer:
+                layer = TransformerDecoderLayer(**layer_opts_)
+            elif isinstance(layer, type):
+                layer = layer(**layer_opts_)
+            elif isinstance(layer, dict):
+                layer_opts_ = {k: v for (k, v) in layer_opts_.items() if k not in layer}
+                layer = rf.build_from_dict(layer, **layer_opts_)
         else:
-            raise TypeError(f"unexpected decoder_layer {decoder_layer!r}")
+            raise TypeError(f"unexpected layer {layer!r}")

-        self.layers = sequential(_copy.deepcopy(decoder_layer) for _ in range(num_layers))
+        self.layers = sequential(_copy.deepcopy(layer) for _ in range(num_layers))

         self.final_layer_norm = make_norm(norm, model_dim)
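For reference, a rough usage sketch of the new decoder calling convention (not part of the commit). The parameter names follow the diff above; the concrete sizes, the Dim construction, and the rf.build_dict helper (assumed here as the dict-producing counterpart of rf.build_from_dict) are illustrative assumptions:

import returnn.frontend as rf
from returnn.tensor import Dim
from returnn.frontend.decoder.transformer import TransformerDecoder, TransformerDecoderLayer

# Hypothetical vocab size, only for illustration.
vocab_dim = Dim(10_000, name="vocab")

decoder = TransformerDecoder(
    encoder_dim=None,  # no cross-attention in this sketch
    vocab_dim=vocab_dim,
    model_dim=512,
    num_layers=6,
    # New: `layer` can be a build dict (assumes rf.build_dict produces the dict
    # that rf.build_from_dict consumes); keys set here win over the defaults.
    layer=rf.build_dict(TransformerDecoderLayer, num_heads=4),
    # Renamed from `decoder_layer_opts`; extra kwargs for the layer class.
    layer_opts=dict(att_dropout=0.05),
)

With a dict, the defaults (out_dim, ff, dropout, norm, ...) are still filled in, but only for keys the dict does not set itself, per the `k not in layer` filter in the diff.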
36 changes: 26 additions & 10 deletions returnn/frontend/encoder/transformer.py
@@ -32,11 +32,13 @@ def __init__(
         num_heads: int = 8,
         att_dropout: float = 0.1,
         norm: Union[type, Dict[str, Any], rf.Module, Callable] = rf.LayerNorm,
-        decoder_layer: Optional[Union[TransformerEncoderLayer, rf.Module, type, Any]] = None,
+        layer: Optional[Union[TransformerEncoderLayer, rf.Module, type, Dict[str, Any], Any]] = None,
+        layer_opts: Optional[Dict[str, Any]] = None,
         embed_dim: Optional[Dim] = None,
         input_embedding_scale: float = None,
         input_dropout: float = None,
         sequential=rf.Sequential,
+        **compat_kwargs,
     ):
         """
         :param vocab_dim:
@@ -48,14 +50,22 @@
         :param num_heads: the number of attention heads
         :param att_dropout: attention dropout value
         :param norm: pre-normalization for FF and attention blocks
-        :param decoder_layer: an instance of :class:`TransformerDecoderLayer` or similar
+        :param layer: an instance of :class:`TransformerEncoderLayer` or similar
+        :param layer_opts: options for the encoder layer
         :param embed_dim: if given, will first have an embedding [vocab,embed] and then a linear [embed,model].
         :param input_embedding_scale:
         :param input_dropout:
         :param sequential:
         """
         super().__init__()

+        if compat_kwargs:
+            if "decoder_layer" in compat_kwargs:  # compatibility, we (weirdly) used to have this before
+                assert layer is None
+                layer = compat_kwargs.pop("decoder_layer")
+            if compat_kwargs:
+                raise TypeError(f"unexpected kwargs {compat_kwargs!r}")
+
         if not isinstance(vocab_dim, Dim):
             raise TypeError(f"TransformerDecoder: unexpected vocab_dim {vocab_dim!r} type {type(vocab_dim)}")
         if isinstance(model_dim, int):
@@ -97,23 +107,29 @@ def __init__(
             input_dropout = dropout
         self.input_dropout = input_dropout

-        if not decoder_layer or isinstance(decoder_layer, type):
-            decoder_layer_opts_ = dict(
+        if not layer or isinstance(layer, (dict, type)):
+            layer_opts_ = dict(
                 out_dim=model_dim,
                 ff=ff,
                 dropout=dropout,
                 num_heads=num_heads,
                 att_dropout=att_dropout,
                 norm=norm,
             )
-            if not decoder_layer:
-                decoder_layer = TransformerEncoderLayer(**decoder_layer_opts_)
-            elif isinstance(decoder_layer, type):
-                decoder_layer = decoder_layer(**decoder_layer_opts_)
+            layer_opts_ = {k: v for (k, v) in layer_opts_.items() if v is not NotSpecified}
+            if layer_opts:
+                layer_opts_.update(layer_opts)
+            if not layer:
+                layer = TransformerEncoderLayer(**layer_opts_)
+            elif isinstance(layer, type):
+                layer = layer(**layer_opts_)
+            elif isinstance(layer, dict):
+                layer_opts_ = {k: v for (k, v) in layer_opts_.items() if k not in layer}
+                layer = rf.build_from_dict(layer, **layer_opts_)
         else:
-            raise TypeError(f"unexpected decoder_layer {decoder_layer!r}")
+            raise TypeError(f"unexpected layer {layer!r}")

-        self.layers = sequential(_copy.deepcopy(decoder_layer) for _ in range(num_layers))
+        self.layers = sequential(_copy.deepcopy(layer) for _ in range(num_layers))

         self.final_layer_norm = make_norm(norm, model_dim)
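And a similar sketch for the encoder, including the backward-compatibility path (again not part of the commit; the sizes and the Dim construction are made up, the keyword handling follows the compat_kwargs shim shown above):

from returnn.tensor import Dim
from returnn.frontend.encoder.transformer import TransformerEncoder, TransformerEncoderLayer

vocab_dim = Dim(10_000, name="vocab")  # hypothetical vocab size

# New style: pass the layer class (or an instance, or a dict) via `layer`,
# with extra options via `layer_opts`; unset defaults are filled in.
enc = TransformerEncoder(
    vocab_dim=vocab_dim,
    model_dim=512,
    num_layers=12,
    layer=TransformerEncoderLayer,
    layer_opts=dict(num_heads=4),
)

# Old style: the historical (misnamed) `decoder_layer=` keyword is still accepted
# and remapped to `layer` by **compat_kwargs; any other unknown kwarg raises TypeError.
enc_compat = TransformerEncoder(
    vocab_dim=vocab_dim,
    model_dim=512,
    num_layers=12,
    decoder_layer=TransformerEncoderLayer,
)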
