
Commit

Tiny tango tweaks (#5383)
* Fixes SqliteSparseSequence for empty extends

* Fixes type annotation

* Gives a length to the transformer text field

* Test for empty extends

* Changelog

* Formatting

* Fixes the sqlite format

* Get output dims

* Fixes test
dirkgr authored Aug 30, 2021
1 parent 2895021 · commit 60213cd
Showing 10 changed files with 24 additions and 8 deletions.
CHANGELOG.md (2 additions, 0 deletions)
@@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added a way for AllenNLP Tango to read and write datasets lazily.
 - Added a way to remix datasets flexibly
 - Added `from_pretrained_transformer_and_instances` constructor to `Vocabulary`
+- `TransformerTextField` now supports `__len__`.
 
 ### Fixed
 
@@ -50,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - The `MultiProcessDataLoader` will properly shutdown its workers when a `SIGTERM` is received.
 - Fixed the way names are applied to Tango `Step` instances.
 - Fixed a bug in calculating loss in the distributed setting.
+- Fixed a bug when extending a sparse sequence by 0 items.
 
 ### Changed

allennlp/common/sqlite_sparse_sequence.py (3 additions, 0 deletions)
@@ -63,8 +63,11 @@ def __delitem__(self, i: Union[int, slice]):
 
     def extend(self, values: Iterable[Any]) -> None:
         current_length = len(self)
+        index = -1
         for index, value in enumerate(values):
             self.table[str(index + current_length)] = value
+        if index < 0:
+            return
         self.table["_len"] = current_length + index + 1
         self.table.commit()

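Why the `index = -1` guard: Python leaves a `for` loop's variable unbound when the iterable is empty, so the old `extend` crashed on `extend([])`. A minimal stand-in, using a plain dict in place of the sqlite-backed table, reproduces the pre-fix failure:

    def extend_buggy(table: dict, values) -> None:
        # Mirrors the old extend(): `index` is bound only inside the loop...
        current_length = len(table)
        for index, value in enumerate(values):
            table[str(index + current_length)] = value
        # ...so an empty `values` leaves it unbound and the next line raises
        # UnboundLocalError instead of being a harmless no-op.
        table["_len"] = current_length + index + 1

    extend_buggy({}, [])  # raises UnboundLocalError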
allennlp/data/fields/transformer_text_field.py (3 additions, 0 deletions)
@@ -114,3 +114,6 @@ def readable_tensor(t: torch.Tensor) -> str:
             for name in self.__slots__
             if isinstance(getattr(self, name), torch.Tensor)
         }
+
+    def __len__(self):
+        return len(self.input_ids)
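With `__len__` in place, a field's token count is available through plain `len()`. A hedged usage sketch, assuming `input_ids` is the field's first constructor argument as in the diff above (the token values are made up):

    import torch
    from allennlp.data.fields import TransformerTextField

    field = TransformerTextField(input_ids=torch.tensor([101, 2023, 2003, 102]))
    assert len(field) == 4  # delegates to len(field.input_ids)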
allennlp/modules/transformer/output_layer.py (3 additions, 0 deletions)
@@ -16,6 +16,9 @@ def __init__(self, input_size: int, hidden_size: int, dropout: float):
         self.layer_norm = LayerNorm(hidden_size, eps=1e-12)
         self.dropout = torch.nn.Dropout(dropout)
 
+    def get_output_dim(self) -> int:
+        return self.dense.out_features
+
     def forward(self, hidden_states, input_tensor):
         dense_output = self.dense(hidden_states)
         dropout_output = self.dropout(dense_output)
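The new getter only surfaces an attribute PyTorch already tracks: every `torch.nn.Linear` records its output width as `out_features`. For example:

    import torch

    dense = torch.nn.Linear(in_features=3072, out_features=768)
    assert dense.out_features == 768  # the value get_output_dim() reports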
allennlp/modules/transformer/transformer_layer.py (3 additions, 3 deletions)
@@ -159,9 +159,6 @@ def __init__(
     ):
         super().__init__()
 
-        self._hidden_size = hidden_size
-        self._add_cross_attention = add_cross_attention
-
         self.attention = AttentionLayer(
             hidden_size=hidden_size,
             num_attention_heads=num_attention_heads,
@@ -186,6 +183,9 @@ def __init__(
             input_size=intermediate_size, hidden_size=hidden_size, dropout=hidden_dropout
         )
 
+    def get_output_dim(self) -> int:
+        return self.output.get_output_dim()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
allennlp/modules/transformer/transformer_stack.py (3 additions, 0 deletions)
@@ -92,6 +92,9 @@ def __init__(
         )
         self.layers = replicate_layers(layer, num_hidden_layers)
 
+    def get_output_dim(self) -> int:
+        return self.layers[-1].get_output_dim()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
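Taken together, the three `get_output_dim()` additions form a delegation chain: the stack asks its last layer, the layer asks its output sublayer, and the sublayer reads `out_features` off its linear projection. A self-contained mimic of that chain (class names shortened; these are not the real AllenNLP modules):

    import torch

    class MiniOutput(torch.nn.Module):
        def __init__(self, input_size: int, hidden_size: int):
            super().__init__()
            self.dense = torch.nn.Linear(input_size, hidden_size)

        def get_output_dim(self) -> int:
            return self.dense.out_features  # OutputLayer's strategy

    class MiniLayer(torch.nn.Module):
        def __init__(self, intermediate_size: int, hidden_size: int):
            super().__init__()
            self.output = MiniOutput(intermediate_size, hidden_size)

        def get_output_dim(self) -> int:
            return self.output.get_output_dim()  # TransformerLayer delegates down

    layers = torch.nn.ModuleList([MiniLayer(3072, 768) for _ in range(4)])
    assert layers[-1].get_output_dim() == 768  # TransformerStack asks its last layer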
allennlp/tango/sqlite_format.py (2 additions, 2 deletions)
@@ -25,10 +25,10 @@ def write(self, artifact: DatasetDict, dir: Union[str, PathLike]):
             filename = f"{split_name}.sqlite"
             if not filename_is_safe(filename):
                 raise ValueError(f"{split_name} is not a valid name for a split.")
+            (dir / filename).unlink(missing_ok=True)
             if isinstance(split, SqliteSparseSequence):
-                split.copy_to(filename)
+                split.copy_to(dir / filename)
             else:
-                (dir / filename).unlink(missing_ok=True)
                 sqlite = SqliteSparseSequence(dir / filename)
                 sqlite.extend(split)

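This fix addresses two problems: `copy_to(filename)` used to resolve the bare filename against the current working directory rather than the output directory, and the stale-file cleanup ran only on the non-sqlite branch. A small pathlib sketch of the path-anchoring half (directory names are illustrative):

    from pathlib import Path

    out_dir = Path("/tmp/tango-output")
    filename = "train.sqlite"

    print(Path(filename).resolve())        # resolves against the CWD (the pre-fix behavior)
    print((out_dir / filename).resolve())  # anchored to the output directory (the fix)

    # The hoisted unlink() clears any stale file before either branch writes:
    (out_dir / filename).unlink(missing_ok=True)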
allennlp/tango/step.py (2 additions, 2 deletions)
@@ -348,7 +348,7 @@ def __init__(
         self.only_if_needed = only_if_needed
 
         self.work_dir_for_run: Optional[
-            PathLike
+            Path
         ] = None  # This is set only while the run() method runs.
 
     @classmethod
@@ -496,7 +496,7 @@ def _run_with_work_dir(self, cache: StepCache, **kwargs) -> T:
             # No cleanup, as we want to keep the directory for restarts or serialization.
             self.work_dir_for_run = None
 
-    def work_dir(self) -> PathLike:
+    def work_dir(self) -> Path:
         """
         Returns a work directory that a step can use while its `run()` method runs.
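Narrowing the annotation from `PathLike` to `Path` is more than cosmetic: `os.PathLike` promises only an `__fspath__()` method, while `pathlib.Path` supports the `/` operator and filesystem calls that step code typically needs. A sketch of the difference (function names are hypothetical):

    from os import PathLike
    from pathlib import Path

    def make_work_dir(base: Path) -> Path:
        # Fine: Path supports composition and mkdir directly.
        (base / "work").mkdir(parents=True, exist_ok=True)
        return base / "work"

    def make_work_dir_loose(base: PathLike) -> Path:
        # A type checker rejects `base / "work"` here: PathLike guarantees
        # only __fspath__(), so callers must convert to Path first.
        return Path(base) / "work"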
tests/common/sqlite_sparse_sequence_test.py (2 additions, 0 deletions)
@@ -8,6 +8,8 @@ def test_sqlite_sparse_sequence():
     with TemporaryDirectory(prefix="test_sparse_sequence-") as temp_dir:
         s = SqliteSparseSequence(os.path.join(temp_dir, "test.sqlite"))
         assert len(s) == 0
+        s.extend([])
+        assert len(s) == 0
         s.append("one")
         assert len(s) == 1
         s.extend(["two", "three"])
tests/modules/transformer/transformer_stack_test.py (1 addition, 1 deletion)
@@ -89,7 +89,7 @@ def test_loading_from_pretrained(pretrained_model_name):
 
     batch_size = 2
     seq_length = 15
-    hidden_size = transformer_stack.layers[0]._hidden_size
+    hidden_size = transformer_stack.layers[0].get_output_dim()
 
     hidden_states = torch.randn(batch_size, seq_length, hidden_size)
     attention_mask = torch.randint(0, 2, (batch_size, seq_length))
