diff --git a/returnn/datasets/lm.py b/returnn/datasets/lm.py index af5a92ec3..9988bef92 100644 --- a/returnn/datasets/lm.py +++ b/returnn/datasets/lm.py @@ -67,6 +67,7 @@ def __init__( error_on_invalid_seq=True, add_delayed_seq_data=False, delayed_seq_data_start_symbol="[START]", + dtype: Optional[str] = None, **kwargs, ): """ @@ -117,6 +118,7 @@ def __init__( :param bool add_delayed_seq_data: will add another data-key "delayed" which will have the sequence. delayed_seq_data_start_symbol + original_sequence[:-1]. :param str delayed_seq_data_start_symbol: used for add_delayed_seq_data. + :param dtype: explicit dtype. if not given, automatically determined based on the number of labels. """ super(LmDataset, self).__init__(**kwargs) @@ -245,7 +247,9 @@ def __init__( assert not orth_replace_map_file num_labels = len(self.labels["data"]) - if num_labels <= 2**7: + if dtype: + self.dtype = dtype + elif num_labels <= 2**7: self.dtype = "int8" elif num_labels <= 2**8: self.dtype = "uint8"