Skip to content

Commit

Permalink
LmDataset, explicit dtype arg
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Dec 23, 2024
1 parent 0859426 commit d5f6aa4
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion returnn/datasets/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
error_on_invalid_seq=True,
add_delayed_seq_data=False,
delayed_seq_data_start_symbol="[START]",
dtype: Optional[str] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -117,6 +118,7 @@ def __init__(
:param bool add_delayed_seq_data: will add another data-key "delayed" which will have the sequence
delayed_seq_data_start_symbol + original_sequence[:-1].
:param str delayed_seq_data_start_symbol: used for add_delayed_seq_data.
:param dtype: explicit dtype. If not given, it is automatically determined based on the number of labels.
"""
super(LmDataset, self).__init__(**kwargs)

Expand Down Expand Up @@ -245,7 +247,9 @@ def __init__(
assert not orth_replace_map_file

num_labels = len(self.labels["data"])
if num_labels <= 2**7:
if dtype:
self.dtype = dtype
elif num_labels <= 2**7:
self.dtype = "int8"
elif num_labels <= 2**8:
self.dtype = "uint8"
Expand Down

0 comments on commit d5f6aa4

Please sign in to comment.