RF test_causal_self_att_variants_single_step_vs_full_seq
albertz committed Dec 10, 2024
1 parent 660838c commit b091bd6
Showing 1 changed file with 79 additions and 1 deletion.
80 changes: 79 additions & 1 deletion tests/test_rf_attention.py
@@ -3,7 +3,7 @@
"""

from __future__ import annotations
-from typing import Tuple
+from typing import Union, Tuple
import numpy as np
import numpy.testing
import _setup_test_env # noqa
@@ -440,6 +440,84 @@ def test_rope_causal_self_att():
    print("  all matched!")


def test_causal_self_att_variants_single_step_vs_full_seq():
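    """Check that the causal self-attention variants give matching outputs when run over a full sequence vs. single steps."""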
    from returnn.tensor import single_step_dim

    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
    in_dim = Dim(7 * 2, name="in")
    extern_data = TensorDict(
        {
            "data": Tensor("data", [batch_dim, time_dim, in_dim], dtype="float32"),
        }
    )

    # noinspection PyShadowingNames
    def _forward_step(*, model: Union[rf.CausalSelfAttention], extern_data: TensorDict):
        x = extern_data["data"]

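        # Variant 1: run over the whole sequence at once, with the default (implicit) initial state.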
        out_seq_level, _ = model(x, axis=time_dim)
        out_seq_level.mark_as_output("out_seq_level", shape=[batch_dim, time_dim, model.out_dim])

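        # Variant 2: whole sequence again, but with the initial state passed explicitly.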
        out_seq_level_explicit_initial_state, _ = model(
            x, axis=time_dim, state=model.default_initial_state(batch_dims=[batch_dim])
        )
        out_seq_level_explicit_initial_state.mark_as_output(
            "out_seq_level_explicit_initial_state", shape=[batch_dim, time_dim, model.out_dim]
        )

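        # Variant 3: loop over the sequence frame by frame. rf.scan threads the
        # attention state through, and axis=single_step_dim processes one frame per call.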
        def _body(
            _x: Tensor, _state: Union[rf.CausalSelfAttentionState]
        ) -> Tuple[Tensor, Union[rf.CausalSelfAttentionState]]:
            return model(_x, axis=single_step_dim, state=_state)

        out_single_steps, _, _ = rf.scan(
            spatial_dim=time_dim,
            xs=x,
            body=_body,
            ys=Tensor("y", dims=[batch_dim, model.out_dim], dtype="float32"),
            initial=model.default_initial_state(batch_dims=[batch_dim]),
        )
        out_single_steps.mark_as_output("out_single_steps", shape=[batch_dim, time_dim, model.out_dim])

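    # Options shared by all three attention variants under test.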
    common_opts = dict(
        in_dim=in_dim,
        proj_dim=Dim(5, name="out"),
        key_dim_total=Dim(21 * 2, name="key-dim-total"),
        value_dim_total=Dim(33, name="value-dim-total"),
        num_heads=3,
    )

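    # Model factories; run_model calls one of these to build the model (any kwargs it passes are ignored).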
    def _make_causal_self_att(**_kwargs):
        return rf.CausalSelfAttention(**common_opts)

    def _make_rope_causal_self_att(**_kwargs):
        return rf.RotaryPosCausalSelfAttention(**common_opts)

    def _make_rel_pos_causal_self_att(**_kwargs):
        return rf.RelPosCausalSelfAttention(**common_opts)

    models = [_make_causal_self_att, _make_rope_causal_self_att, _make_rel_pos_causal_self_att]

    for get_model in models:
        print("> Testing model:", get_model.__name__)
        res = run_model(
            extern_data,
            get_model,
            _forward_step,
            # TF needs TensorArray unstack, not implemented yet
            test_tensorflow=False,
        )

        # Check that the single-step and the seq-level outputs are the same.
        res_seq_level = res.data["out_seq_level"].raw_tensor
        for key in ["out_seq_level_explicit_initial_state", "out_single_steps"]:
            res_other = res.data[key].raw_tensor
            assert res_seq_level.shape == res_other.shape
            numpy.testing.assert_allclose(
                res_other, res_seq_level, atol=1e-5, rtol=1e-5, err_msg=f"output {key} differs"
            )


def test_relative_positional_encoding():
    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
    in_dim = Dim(8, name="in")