Skip to content

Commit

Permalink
Replace iterator class with generator.
Browse files Browse the repository at this point in the history
  • Loading branch information
jzohrab committed Jan 9, 2025
1 parent 4e168d2 commit 2a7fb85
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 50 deletions.
59 changes: 16 additions & 43 deletions lute/book/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,57 +10,31 @@
)


def sentence_group_iterator(tokens, maxcount=500):
    """
    A generator that yields groups of ParsedTokens grouped by sentence,
    up to a maximum number of tokens.

    Each yielded group is a contiguous slice of ``tokens``; the groups
    together cover every token exactly once, in order.

    Args:
        tokens: sliceable sequence of token objects exposing the
            ``is_end_of_sentence`` and ``is_word`` attributes (both are
            only checked for truthiness).
        maxcount: target maximum number of word tokens per group.  Only
            tokens whose ``is_word`` is truthy count toward this limit.

    Yields:
        Slices of ``tokens``.  A group is cut just after the last
        end-of-sentence token seen before the word limit was exceeded.
        If no end-of-sentence token has been seen yet, the group keeps
        growing until one is found (or the tokens run out), so a group
        may hold more than ``maxcount`` words.
    """
    total = len(tokens)
    currpos = 0
    while currpos < total:
        curr_tok_count = 0
        last_eos = -1
        i = currpos

        # Scan forward until the word limit is exceeded AND a sentence
        # end is available to cut at, or until the tokens are exhausted.
        while (curr_tok_count <= maxcount or last_eos == -1) and i < total:
            tok = tokens[i]
            if tok.is_end_of_sentence:
                last_eos = i
            if tok.is_word:
                curr_tok_count += 1
            i += 1

        if curr_tok_count <= maxcount or last_eos == -1:
            # The scan condition still holds, so the inner loop stopped
            # only because it reached the end: emit everything left.
            yield tokens[currpos:i]
            currpos = i
        else:
            # Over the limit: back up to the last sentence boundary so
            # that sentences are never split across groups.
            yield tokens[currpos : last_eos + 1]
            currpos = last_eos + 1


class Book: # pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -179,8 +153,7 @@ def _split_by_sentences(self, language, fulltext, max_word_tokens_per_text=250):
pages = []
for segment in self._split_text_at_page_breaks(fulltext):
tokens = language.parser.get_parsed_tokens(segment, language)
it = SentenceGroupIterator(tokens, max_word_tokens_per_text)
while toks := it.next():
for toks in sentence_group_iterator(tokens, max_word_tokens_per_text):
s = (
"".join([t.token for t in toks])
.replace("\r", "")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
SentenceGroupIterator tests.
sentence_group_iterator tests.
"""

from lute.book.model import SentenceGroupIterator
from lute.book.model import sentence_group_iterator
from lute.parse.space_delimited_parser import SpaceDelimitedParser


Expand All @@ -13,17 +13,14 @@ def toks_to_string(tokens):
def test_sgi_scenarios(english):
"""
Given a string and the max token count,
SentenceGroupIterator should return expected groups.
generator should return expected groups.
"""

def scenario(s, maxcount, expected_groups):
parser = SpaceDelimitedParser()
tokens = parser.get_parsed_tokens(s, english)

it = SentenceGroupIterator(tokens, maxcount)
groups = []
while g := it.next():
groups.append(g)
groups = [g for g in sentence_group_iterator(tokens, maxcount)]
gs = [toks_to_string(g) for g in groups]
assert "||".join(gs) == "||".join(
expected_groups
Expand Down

0 comments on commit 2a7fb85

Please sign in to comment.