From b091bd62699111d0a3a5f6a12413129bb93b967d Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 10 Dec 2024 11:16:36 +0100
Subject: [PATCH] RF test_causal_self_att_variants_single_step_vs_full_seq

---
# Reviewer notes (free text in the notes area after the three-dash line).
#
# What this patch does:
#   * Widens the typing import in tests/test_rf_attention.py from `Tuple`
#     to `Union, Tuple`.
#   * Adds test_causal_self_att_variants_single_step_vs_full_seq().
#
# What the new test checks:
#   * Three factories build rf.CausalSelfAttention,
#     rf.RotaryPosCausalSelfAttention and rf.RelPosCausalSelfAttention from
#     one shared `common_opts` dict (in_dim, proj_dim, key_dim_total,
#     value_dim_total, num_heads=3).
#   * For each model, _forward_step marks three outputs:
#       - "out_seq_level": model(x, axis=time_dim), no explicit state.
#       - "out_seq_level_explicit_initial_state": same call, with
#         state=model.default_initial_state(batch_dims=[batch_dim]).
#       - "out_single_steps": rf.scan over time_dim whose body calls the
#         model with axis=single_step_dim, starting from the same default
#         initial state.
#   * The last two outputs must have the same shape as "out_seq_level" and
#     match it under numpy.testing.assert_allclose(atol=1e-5, rtol=1e-5).
#   * run_model is called with test_tensorflow=False; the in-code comment
#     gives the reason (TF needs TensorArray unstack, not implemented yet).
#
# NOTE(review): the factories accept **_kwargs and ignore them; presumably
#   run_model passes keyword arguments to get_model -- confirm.
# NOTE(review): the annotations use Union[...] with a single member
#   (Union[rf.CausalSelfAttention], Union[rf.CausalSelfAttentionState]);
#   this looks like a placeholder for listing further variants -- confirm
#   the intent.
# NOTE(review): line breaks and indentation in this copy were reconstructed
#   from a whitespace-collapsed source. The indentation of the context lines
#   (in particular the print(...) line of test_rope_causal_self_att), the
#   spacing inside that print string literal and the spacing before "# noqa"
#   could not be recovered with certainty -- verify against the tree before
#   applying.
# NOTE(review): the From: header carries no e-mail address in this copy;
#   mail-based tooling may reject it -- confirm against the original commit.
#
 tests/test_rf_attention.py | 80 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 79 insertions(+), 1 deletion(-)

diff --git a/tests/test_rf_attention.py b/tests/test_rf_attention.py
index 01357c51e..94f61d193 100644
--- a/tests/test_rf_attention.py
+++ b/tests/test_rf_attention.py
@@ -3,7 +3,7 @@
 """
 
 from __future__ import annotations
-from typing import Tuple
+from typing import Union, Tuple
 import numpy as np
 import numpy.testing
 import _setup_test_env  # noqa
@@ -440,6 +440,84 @@ def test_rope_causal_self_att():
     print(" all matched!")
 
 
+def test_causal_self_att_variants_single_step_vs_full_seq():
+    from returnn.tensor import single_step_dim
+
+    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
+    in_dim = Dim(7 * 2, name="in")
+    extern_data = TensorDict(
+        {
+            "data": Tensor("data", [batch_dim, time_dim, in_dim], dtype="float32"),
+        }
+    )
+
+    # noinspection PyShadowingNames
+    def _forward_step(*, model: Union[rf.CausalSelfAttention], extern_data: TensorDict):
+        x = extern_data["data"]
+
+        out_seq_level, _ = model(x, axis=time_dim)
+        out_seq_level.mark_as_output("out_seq_level", shape=[batch_dim, time_dim, model.out_dim])
+
+        out_seq_level_explicit_initial_state, _ = model(
+            x, axis=time_dim, state=model.default_initial_state(batch_dims=[batch_dim])
+        )
+        out_seq_level_explicit_initial_state.mark_as_output(
+            "out_seq_level_explicit_initial_state", shape=[batch_dim, time_dim, model.out_dim]
+        )
+
+        def _body(
+            _x: Tensor, _state: Union[rf.CausalSelfAttentionState]
+        ) -> Tuple[Tensor, Union[rf.CausalSelfAttentionState]]:
+            return model(_x, axis=single_step_dim, state=_state)
+
+        out_single_steps, _, _ = rf.scan(
+            spatial_dim=time_dim,
+            xs=x,
+            body=_body,
+            ys=Tensor("y", dims=[batch_dim, model.out_dim], dtype="float32"),
+            initial=model.default_initial_state(batch_dims=[batch_dim]),
+        )
+        out_single_steps.mark_as_output("out_single_steps", shape=[batch_dim, time_dim, model.out_dim])
+
+    common_opts = dict(
+        in_dim=in_dim,
+        proj_dim=Dim(5, name="out"),
+        key_dim_total=Dim(21 * 2, name="key-dim-total"),
+        value_dim_total=Dim(33, name="value-dim-total"),
+        num_heads=3,
+    )
+
+    def _make_causal_self_att(**_kwargs):
+        return rf.CausalSelfAttention(**common_opts)
+
+    def _make_rope_causal_self_att(**_kwargs):
+        return rf.RotaryPosCausalSelfAttention(**common_opts)
+
+    def _make_rel_pos_causal_self_att(**_kwargs):
+        return rf.RelPosCausalSelfAttention(**common_opts)
+
+    models = [_make_causal_self_att, _make_rope_causal_self_att, _make_rel_pos_causal_self_att]
+
+    for get_model in models:
+        print("> Testing model:", get_model.__name__)
+        res = run_model(
+            extern_data,
+            get_model,
+            _forward_step,
+            # TF needs TensorArray unstack, not implemented yet
+            test_tensorflow=False,
+        )
+
+        # Check that the single-step and the seq-level output are the same.
+        res_seq_level = res.data["out_seq_level"].raw_tensor
+        for key in ["out_seq_level_explicit_initial_state", "out_single_steps"]:
+            res_other = res.data[key].raw_tensor
+            assert res_seq_level.shape == res_other.shape
+            numpy.testing.assert_allclose(
+                res_other, res_seq_level, atol=1e-5, rtol=1e-5, err_msg=f"output {key} differs"
+            )
+
+
 def test_relative_positional_encoding():
     time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
     in_dim = Dim(8, name="in")