From 075315430ed37f02622abfe2a83c6914300c0832 Mon Sep 17 00:00:00 2001 From: John Baublitz Date: Mon, 9 Sep 2024 14:59:40 +0200 Subject: [PATCH] Improve database schema and usability --- language-practice | 6 +- language_practice/config.py | 9 +- language_practice/gui.py | 76 +++--- language_practice/repetition.py | 3 +- language_practice/sqlite.py | 398 ++++++++++++++++++++++---------- language_practice/terminal.py | 112 ++++++--- 6 files changed, 420 insertions(+), 184 deletions(-) diff --git a/language-practice b/language-practice index 60f4ebb..81b8b64 100755 --- a/language-practice +++ b/language-practice @@ -11,9 +11,7 @@ Inflection charts are pulled from wiktionary. import argparse import sys -from language_practice.gui import GuiApplication from language_practice.sqlite import SqliteHandle -from language_practice.terminal import TerminalApplication class Once(argparse.Action): @@ -53,9 +51,11 @@ def main(): try: all_sets = handle.get_all_sets() if args.gui: - gui = GuiApplication(handle, all_sets) + from language_practice.gui import GuiApplication # pylint: disable=import-outside-toplevel + gui = GuiApplication(handle, all_sets, application_id="me.jbaublitz.LanguagePractice") gui.run() else: + from language_practice.terminal import TerminalApplication # pylint: disable=import-outside-toplevel tui = TerminalApplication(handle, all_sets) tui.run() except Exception as err: # pylint: disable=broad-exception-caught diff --git a/language_practice/config.py b/language_practice/config.py index c4a3523..8ebe712 100644 --- a/language_practice/config.py +++ b/language_practice/config.py @@ -16,6 +16,7 @@ class Entry: """ # pylint: disable=too-many-arguments + # pylint: disable=too-many-positional-arguments def __init__( self, word: str, @@ -24,7 +25,7 @@ def __init__( aspect: str | None, usage: str | None, part_of_speech: str | None, - charts: list[list[list[str]]] | None, + charts: list[list[str]] | None, repetition: WordRepetition, ): self.word = word @@ -72,7 +73,7 @@ def get_part_of_speech(self) -> str | None: """ return self.part_of_speech - def get_charts(self) -> list[list[list[str]]] | None: + def get_charts(self) -> list[list[str]] | None: """ Get charts. """ @@ -118,8 +119,8 @@ def extend(self, config: Self): """ if self.lang != config.lang: raise RuntimeError( - f"Attempted to join a TOML config with lang {self.lang} with \ - one with lang {config.lang}" + f"Attempted to join a TOML config with lang {self.lang} with" + f"one with lang {config.lang}" ) self.words += config.words diff --git a/language_practice/gui.py b/language_practice/gui.py index d201093..90ac3bb 100644 --- a/language_practice/gui.py +++ b/language_practice/gui.py @@ -6,6 +6,7 @@ # pylint: disable=too-few-public-methods import asyncio +from sqlite3 import IntegrityError import tomllib from typing import Self @@ -62,10 +63,8 @@ def __init__( self.flashcard_set_grid = FlashcardSetGrid() for flashcard_set in flashcard_sets: - delete_button = Gtk.Button(label="Delete") - delete_button.connect("clicked", self.delete_flashcard_set) self.flashcard_set_grid.add_row( - Gtk.CheckButton(), Gtk.Label.new(flashcard_set), delete_button + Gtk.CheckButton(), Gtk.Label.new(flashcard_set) ) scrollable = Gtk.ScrolledWindow() scrollable.set_size_request(700, 600) @@ -75,6 +74,9 @@ def __init__( import_button = Gtk.Button(label="Import") import_button.connect("clicked", self.import_button) button_hbox.append(import_button) + delete_button = Gtk.Button(label="Delete") + delete_button.connect("clicked", self.delete_flashcard_set) + button_hbox.append(delete_button) select_all_button = Gtk.Button(label="Select all") select_all_button.connect("clicked", self.flashcard_set_grid.select_all) button_hbox.append(select_all_button) @@ -96,12 +98,18 @@ def import_button(self, button: Gtk.Button): file_dialog = Gtk.FileDialog() file_dialog.open_multiple(callback=self.handle_files) + # pylint: disable=unused-argument def delete_flashcard_set(self, button: Gtk.Button): """ Handle deleting flashcard set on button press. """ - self.handle.delete_set(button.get_prev_sibling().get_text()) - self.flashcard_set_grid.delete_row(button) + selected = self.flashcard_set_grid.get_selected() + selected.sort(reverse=True, key=lambda info: info[1]) + for text, row in selected: + set_id = self.handle.get_id_from_file_name(text) + if set_id is not None: + self.handle.delete_set(set_id) + self.flashcard_set_grid.delete_row(row) def handle_files(self, dialog: Gtk.FileDialog, task: Gio.Task): """ @@ -111,20 +119,34 @@ def handle_files(self, dialog: Gtk.FileDialog, task: Gio.Task): for current_import in self.imports: try: toml = TomlConfig(current_import) - except tomllib.TOMLDecodeError: + except tomllib.TOMLDecodeError as err: + dialog = Gtk.AlertDialog() + dialog.set_message(f"{current_import}: {err}") + dialog.set_modal(True) + dialog.choose() continue - self.handle.import_set( - current_import, - toml, - asyncio.run(scrape(toml.get_words(), toml.get_lang())), - ) + try: + new = self.handle.import_set( + current_import, + toml, + asyncio.run(scrape(toml.get_words(), toml.get_lang())), + ) + except IntegrityError as err: + dialog = Gtk.AlertDialog() + dialog.set_message(f"{current_import}: {err}") + dialog.set_modal(True) + dialog.choose() + set_id = self.handle.get_id_from_file_name(current_import) + if set_id is not None: + self.handle.delete_set(set_id) + continue - delete_button = Gtk.Button(label="Delete") - delete_button.connect("clicked", self.delete_flashcard_set) - self.flashcard_set_grid.add_row( - Gtk.CheckButton(), Gtk.Label.new(current_import), delete_button - ) + if new: + self.flashcard_set_grid.add_row( + Gtk.CheckButton(), + Gtk.Label.new(current_import), + ) self.imports = [] # pylint: disable=unused-argument @@ -134,11 +156,11 @@ def handle_start(self, button: Gtk.Button): """ files = self.flashcard_set_grid.get_selected() config = None - for file in files: + for text, _ in files: if config is None: - config = self.handle.load_config(file) + config = self.handle.load_config(text) else: - config = config.extend(self.handle.load_config(file)) + config = config.extend(self.handle.load_config(text)) if config is not None: self.flashcard = Flashcard(self.handle, config.get_words()) @@ -167,23 +189,20 @@ def __init__(self, *args, **kwargs): self.set_column_spacing(10) self.num_rows = 0 - def add_row( - self, checkbox: Gtk.CheckButton, label: Gtk.Label, delete_button: Gtk.Button - ): + def add_row(self, checkbox: Gtk.CheckButton, label: Gtk.Label): """ Add a row to the grid. """ self.attach(checkbox, 0, self.num_rows, 1, 1) self.attach(label, 1, self.num_rows, 1, 1) - self.attach(delete_button, 2, self.num_rows, 1, 1) self.num_rows += 1 - def delete_row(self, contains_child: Gtk.Button): + def delete_row(self, row): """ Delete a row from the grid. """ - info = self.query_child(contains_child) - self.remove_row(info.row) + self.remove_row(row) + self.num_rows -= 1 # pylint: disable=unused-argument def select_all(self, button: Gtk.Button): @@ -194,14 +213,14 @@ def select_all(self, button: Gtk.Button): self.get_child_at(0, row).set_active(True) # pylint: disable=unused-argument - def get_selected(self) -> list[str]: + def get_selected(self) -> list[tuple[str, int]]: """ Get all selected flashcard sets. """ files = [] for row in range(self.num_rows): if self.get_child_at(0, row).get_active(): - files.append(self.get_child_at(1, row).get_text()) + files.append((self.get_child_at(1, row).get_text(), row)) return files @@ -213,6 +232,7 @@ class StudyWindow(Gtk.ApplicationWindow): def __init__(self, flashcard: Flashcard, *args, **kwargs): super().__init__(*args, **kwargs) + self.set_title("Language Practice") self.flashcard = flashcard (self.peek, self.is_review) = self.flashcard.current() diff --git a/language_practice/repetition.py b/language_practice/repetition.py index 035eafd..2653f7b 100644 --- a/language_practice/repetition.py +++ b/language_practice/repetition.py @@ -14,6 +14,7 @@ class WordRepetition: DEFAULT_EASYNESS_FACTOR = 2.5 # pylint: disable=too-many-arguments + # pylint: disable=too-many-positional-arguments def __init__( self, easiness_factor: float, @@ -40,7 +41,7 @@ def grade(self, grade: int): self.in_n_days = 6 self.date_of_next = date.today() + timedelta(days=self.in_n_days) else: - self.in_n_days = math.ceil(self.in_n_days * self.easiness_factor) + self.in_n_days = math.floor(self.in_n_days * self.easiness_factor) self.date_of_next = date.today() + timedelta(days=self.in_n_days) self.num_correct += 1 diff --git a/language_practice/sqlite.py b/language_practice/sqlite.py index 3970c22..e44cb22 100644 --- a/language_practice/sqlite.py +++ b/language_practice/sqlite.py @@ -2,6 +2,7 @@ Database code """ +import uuid import sqlite3 from datetime import date @@ -13,18 +14,25 @@ class SqliteHandle: Handler for sqlite operations. """ - FLASHCARDS_SCHEMA = "file_name TEXT PRIMARY KEY, lang TEXT" + FLASHCARDS_TABLE_NAME = "flashcard_sets" + FLASHCARDS_SCHEMA = ( + "id INTEGER PRIMARY KEY AUTOINCREMENT, file_name TEXT, lang TEXT" + ) + WORD_TABLE_NAME = "words" WORD_SCHEMA = ( - "word TEXT PRIMARY KEY, definition TEXT, gender TEXT, aspect TEXT, " - "usage TEXT, part_of_speech TEXT, easiness_factor REAL, num_correct INTEGER, " - "in_n_days INTEGER, date_of_next TEXT, review NUMERIC, file_name TEXT" + "word TEXT PRIMARY KEY NOT NULL, definition TEXT NOT NULL, gender TEXT, " + "aspect TEXT, usage TEXT, part_of_speech TEXT, easiness_factor REAL, " + "num_correct INTEGER, in_n_days INTEGER, date_of_next TEXT, review NUMERIC, " + "flashcard_set_id INTEGER, table_uuids TEXT" ) def __init__(self, db: str): self.conn = sqlite3.connect(db) self.cursor = self.conn.cursor() - self.create_table_idempotent("flashcard_sets", SqliteHandle.FLASHCARDS_SCHEMA) + self.create_table_idempotent( + SqliteHandle.FLASHCARDS_TABLE_NAME, SqliteHandle.FLASHCARDS_SCHEMA + ) def create_table_idempotent(self, name: str, schema: str): """ @@ -57,17 +65,231 @@ def insert_into(self, name: str, columns: str, values: str): """ Insert into table. """ - self.cursor.execute(f"INSERT INTO '{name}' ({columns}) VALUES({values});") + self.cursor.execute( + f"INSERT OR IGNORE INTO '{name}' ({columns}) VALUES({values});" + ) + + def update(self, name: str, set_statements: str, condition: str): + """ + Insert into table. + """ + self.cursor.execute(f"UPDATE '{name}' SET {set_statements} WHERE {condition};") - # pylint: disable=too-many-nested-blocks + def get_id_from_file_name(self, file_name: str) -> int | None: + """ + Checks whether the flashcard set already exists. + """ + table_name = SqliteHandle.FLASHCARDS_TABLE_NAME + res = self.cursor.execute( + f"SELECT id FROM {table_name} WHERE file_name = '{file_name}';" + ) + set_id = res.fetchone() + if set_id is not None: + set_id = set_id[0] + return set_id + + # pylint: disable=too-many-locals + def update_set( + self, set_id: int, config: Config, scraped: dict[str, list[list[list[str]]]] + ): + """ + Update an existing flashcard set. + """ + lang = config.get_lang() + self.update( + SqliteHandle.FLASHCARDS_TABLE_NAME, f"lang = '{lang}'", f"id = {set_id}" + ) + + table_name = SqliteHandle.WORD_TABLE_NAME + res = self.cursor.execute( + f"SELECT word FROM {table_name} WHERE flashcard_set_id = {set_id}" + ) + current_words = set(map(lambda word: word[0], res.fetchall())) + config_word_dct = {entry.get_word(): entry for entry in config} + config_words = set(config_word_dct.keys()) + + words_to_add = config_words - current_words + for word in words_to_add: + self.insert_word(set_id, config_word_dct[word], scraped.get(word, None)) + + words_to_update = current_words & config_words + for word in words_to_update: + self.update_word(config_word_dct[word], scraped.get(word, None)) + + words_to_delete = current_words - config_words + for word in words_to_delete: + self.delete(SqliteHandle.WORD_TABLE_NAME, "word = '{word}'") + res = self.cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table';" + ) + names = res.fetchall() + names_to_drop = [ + name[0] + for name in names + if name[0] != "" and not name[0].isspace() and name[0].startswith(word) + ] + for name in names_to_drop: + self.drop_table(name) + + # pylint: disable=too-many-branches # pylint: disable=too-many-statements + def insert_word( + self, set_id: int, entry: Entry, scraped: list[list[list[str]]] | None + ): + """ + Insert a new word into the table. + """ + word = entry.get_word() + definition = entry.get_definition() + gender = entry.get_gender() + aspect = entry.get_aspect() + usage = entry.get_usage() + part_of_speech = entry.get_part_of_speech() + charts = entry.get_charts() + repetition = entry.get_repetition() + easiness_factor = repetition.get_easiness_factor() + num_correct = repetition.get_num_correct() + in_n_days = repetition.get_in_n_days() + date_of_next = repetition.get_date_of_next() + review = 1 if repetition.get_review() else 0 + if charts is None: + final_charts = scraped + else: + final_charts = [charts] + + table_uuids = [] + if final_charts is not None: + for chart in final_charts: + table_uuid = str(uuid.uuid4()) + table_uuids.append(table_uuid) + max_len = max(map(len, chart)) + schema = ", ".join([f"{chr(i + 97)} TEXT" for i in range(0, max_len)]) + self.recreate_table(f"{table_uuid}", schema) + for row in chart: + columns = [] + values = [] + for j in range(0, max_len): + try: + val = row[j] + val = val.replace("'", "''") + except IndexError: + pass + else: + columns.append(chr(j + 97)) + values.append(f"'{val}'") + self.insert_into(table_uuid, ", ".join(columns), ", ".join(values)) + + word = word.replace("'", "''") + + columns = [ + "word", + "definition", + "easiness_factor", + "num_correct", + "in_n_days", + "date_of_next", + "review", + "flashcard_set_id", + ] + values = [ + f"'{word}'", + f"'{definition}'", + f"{easiness_factor}", + f"{num_correct}", + f"{in_n_days}", + f"'{date_of_next}'", + f"{review}", + f"'{set_id}'", + ] + if gender is not None: + columns.append("gender") + values.append(f"'{gender}'") + if aspect is not None: + columns.append("aspect") + values.append(f"'{aspect}'") + if usage is not None: + columns.append("usage") + values.append(f"'{usage}'") + if part_of_speech is not None: + columns.append("part_of_speech") + values.append(f"'{part_of_speech}'") + if len(table_uuids) > 0: + columns.append("table_uuids") + table_uuid_str = ",".join(table_uuids) + values.append(f"'{table_uuid_str}'") + self.insert_into("words", ", ".join(columns), ", ".join(values)) + # pylint: disable=too-many-branches - # pylint: disable=too-many-locals - def import_set( + def update_word(self, entry: Entry, scraped: list[list[list[str]]] | None): + """ + Update existing word in the table. + """ + word = entry.get_word() + definition = entry.get_definition() + gender = entry.get_gender() + aspect = entry.get_aspect() + usage = entry.get_usage() + part_of_speech = entry.get_part_of_speech() + charts = entry.get_charts() + if charts is None: + final_charts = scraped + else: + final_charts = [charts] + + res = self.cursor.execute(f"SELECT table_uuids FROM words WHERE word='{word}';") + table_uuids = res.fetchone()[0] + if table_uuids is not None: + for table_uuid in table_uuids.split(","): + self.drop_table(table_uuid) + + table_uuids = [] + if final_charts is not None: + table_uuids = [] + for i, chart in enumerate(final_charts): + table_uuid = str(uuid.uuid4()) + table_uuids.append(table_uuid) + max_len = max(map(len, chart)) + schema = ", ".join([f"{chr(i + 97)} TEXT" for i in range(0, max_len)]) + self.recreate_table(f"{table_uuid}", schema) + for row in chart: + columns = [] + values = [] + for j in range(0, max_len): + try: + val = row[j] + val = val.replace("'", "''") + except IndexError: + pass + else: + columns.append(chr(j + 97)) + values.append(f"'{val}'") + self.insert_into(table_uuid, ", ".join(columns), ", ".join(values)) + + set_statements = [ + f"word = '{word}'", + f"definition = '{definition}'", + ] + if gender is not None: + set_statements.append(f"gender = '{gender}'") + if aspect is not None: + set_statements.append(f"aspect = '{aspect}'") + if usage is not None: + set_statements.append(f"usage = '{usage}'") + if part_of_speech is not None: + set_statements.append(f"part_of_speech = '{part_of_speech}'") + if len(table_uuids) > 0: + table_uuid_str = ",".join(table_uuids) + set_statements.append(f"table_uuids = '{table_uuid_str}'") + + self.update( + SqliteHandle.WORD_TABLE_NAME, ", ".join(set_statements), f"word = '{word}'" + ) + + def create_new_set( self, file_name: str, config: Config, scraped: dict[str, list[list[list[str]]]] ): """ - Import set into database. + Create a new flashcard set. """ lang = config.get_lang() columns = ["file_name"] @@ -76,108 +298,52 @@ def import_set( columns.append("lang") values.append(f"'{lang}'") self.insert_into("flashcard_sets", ", ".join(columns), ", ".join(values)) + set_id = self.cursor.lastrowid self.create_table_idempotent( - "words", + SqliteHandle.WORD_TABLE_NAME, SqliteHandle.WORD_SCHEMA, ) - for entry in iter(config): - word = entry.get_word() - definition = entry.get_definition() - gender = entry.get_gender() - aspect = entry.get_aspect() - usage = entry.get_usage() - part_of_speech = entry.get_part_of_speech() - charts = entry.get_charts() - repetition = entry.get_repetition() - easiness_factor = repetition.get_easiness_factor() - num_correct = repetition.get_num_correct() - in_n_days = repetition.get_in_n_days() - date_of_next = repetition.get_date_of_next() - review = 1 if repetition.get_review() else 0 - if charts is None: - charts = scraped.get(entry.get_word(), None) - else: - charts = [charts] - - if charts is not None: - for i, chart in enumerate(charts): - max_len = max(map(len, chart)) - schema = ", ".join( - [f"{chr(i + 97)} TEXT" for i in range(0, max_len)] - ) - self.recreate_table(f"{word}-{i}", schema) - for row in chart: - columns = [] - values = [] - for j in range(0, max_len): - try: - val = row[j] - val = val.replace("'", "''") - except IndexError: - pass - else: - columns.append(chr(j + 97)) - values.append(f"'{val}'") - self.insert_into( - f"{word}-{i}", ", ".join(columns), ", ".join(values) - ) - - columns = [ - "word", - "definition", - "easiness_factor", - "num_correct", - "in_n_days", - "date_of_next", - "review", - "file_name", - ] - values = [ - f"'{word}'", - f"'{definition}'", - f"{easiness_factor}", - f"{num_correct}", - f"{in_n_days}", - f"'{date_of_next}'", - f"{review}", - f"'{file_name}'", - ] - if gender is not None: - columns.append("gender") - values.append(f"'{gender}'") - if aspect is not None: - columns.append("aspect") - values.append(f"'{aspect}'") - if usage is not None: - columns.append("usage") - values.append(f"'{usage}'") - if part_of_speech is not None: - columns.append("part_of_speech") - values.append(f"'{part_of_speech}'") - self.insert_into("words", ", ".join(columns), ", ".join(values)) - - def delete_set(self, file_name: str): + if set_id is not None: + for entry in iter(config): + self.insert_word(set_id, entry, scraped.get(entry.get_word(), None)) + + # pylint: disable=too-many-nested-blocks + # pylint: disable=too-many-statements + # pylint: disable=too-many-branches + # pylint: disable=too-many-locals + def import_set( + self, file_name: str, config: Config, scraped: dict[str, list[list[list[str]]]] + ) -> bool: + """ + Import set into database. + """ + set_id = self.get_id_from_file_name(file_name) + if set_id is None: + self.create_new_set(file_name, config, scraped) + return True + + self.update_set(set_id, config, scraped) + return False + + def delete_set(self, set_id: int): """ Delete a set from the database. """ - words = self.cursor.execute( - f"SELECT word FROM 'words' WHERE file_name = '{file_name}';" - ).fetchall() - res = self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") - names = res.fetchall() - names_to_drop = [ - name[0] for name in names for word in words if name[0].startswith(word) - ] - for name in names_to_drop: - self.drop_table(name) - self.drop_table(file_name) - self.delete("words", f"file_name = '{file_name}'") - self.delete("flashcard_sets", f"file_name = '{file_name}'") + res = self.cursor.execute( + f"SELECT table_uuids FROM 'words' WHERE flashcard_set_id = {set_id};" + ) + for uuids in res: + if uuids[0] is not None: + for table_uuid in uuids[0].split(","): + self.drop_table(table_uuid) + self.delete("words", f"flashcard_set_id = {set_id}") + self.delete("flashcard_sets", f"id = {set_id}") def load_config(self, file_name: str) -> Config: """ Load config from database. """ + set_id = self.get_id_from_file_name(file_name) res = self.cursor.execute( f"SELECT lang FROM flashcard_sets WHERE file_name = '{file_name}';" ) @@ -185,14 +351,11 @@ def load_config(self, file_name: str) -> Config: res = self.cursor.execute( f"SELECT word, definition, gender, aspect, usage, part_of_speech, " - f"easiness_factor, num_correct, in_n_days, date_of_next, review " - f"FROM 'words' WHERE file_name = '{file_name}';" + f"easiness_factor, num_correct, in_n_days, date_of_next, review, " + f"table_uuids FROM 'words' WHERE flashcard_set_id = '{set_id}';" ) entries = res.fetchall() - res = self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") - table_names = res.fetchall() - loaded_entries = [] for entry in entries: ( @@ -207,16 +370,17 @@ def load_config(self, file_name: str) -> Config: in_n_days, date_of_next, review, + table_uuids, ) = entry date_of_next = date.fromisoformat(date_of_next) review = review != 0 charts = [] - names_to_get = [name[0] for name in table_names if name[0].startswith(word)] - for name in names_to_get: - res = self.cursor.execute(f"SELECT * FROM '{name}';") - chart = res.fetchall() - charts.append(chart) + if table_uuids is not None: + for name in table_uuids.split(","): + res = self.cursor.execute(f"SELECT * FROM '{name}';") + chart = res.fetchall() + charts.append(chart) if date.today() >= date_of_next or review: loaded_entries.append( @@ -249,17 +413,19 @@ def update_config(self, word: str, repetition: WordRepetition): in_n_days = repetition.get_in_n_days() date_of_next = str(repetition.get_date_of_next()) review = 1 if repetition.get_review() else 0 - self.cursor.execute( - f"UPDATE words SET easiness_factor = {easiness_factor}, num_correct = " - f"{num_correct}, in_n_days = {in_n_days}, date_of_next = '{date_of_next}', " - f"review = {review} WHERE word = '{word}';" + self.update( + SqliteHandle.WORD_TABLE_NAME, + f"easiness_factor = {easiness_factor}, num_correct = {num_correct}, " + f"in_n_days = {in_n_days}, date_of_next = '{date_of_next}', " + f"review = {review}", + f"word = '{word}'", ) def get_all_sets(self) -> list[str]: """ Get all flashcard sets from database. """ - res = self.cursor.execute("SELECT * FROM flashcard_sets;") + res = self.cursor.execute("SELECT file_name FROM flashcard_sets;") return [entry[0] for entry in res.fetchall()] def close(self): diff --git a/language_practice/terminal.py b/language_practice/terminal.py index 6290b89..2b0b57f 100644 --- a/language_practice/terminal.py +++ b/language_practice/terminal.py @@ -3,6 +3,7 @@ """ import os +from sqlite3 import IntegrityError import tomllib from uuid import uuid4 @@ -14,8 +15,6 @@ ScrollableContainer, Vertical, ) -from textual.css.query import NoMatches -from textual.dom import DOMNode from textual.screen import ModalScreen from textual.widgets import ( Button, @@ -71,10 +70,7 @@ def compose(self): args.append( Horizontal( Checkbox(id=f"select-{hex_string}"), - Container(), Label(flashcard_set), - Container(), - Button("Delete", id=f"delete-{hex_string}"), classes="flashcard-set", ) ) @@ -82,6 +78,7 @@ def compose(self): yield Horizontal( Container(), Button("Import", id="import"), + Button("Delete", id="delete"), Button("Select all", id="select_all"), Button("Start", id="start"), Button("Exit", id="exit"), @@ -89,13 +86,14 @@ def compose(self): classes="bottom-buttons-one", ) - def on_start(self): + async def on_start(self): """ On start button press. """ scrollable = self.query_one("#scrollable") checkboxes = map(lambda child: child.query_one(Checkbox), scrollable.children) config = None + screens = [] for checkbox in checkboxes: if checkbox.value: checkbox_id = checkbox.id @@ -106,19 +104,37 @@ def on_start(self): if config is None: config = self.handle.load_config(name) else: - config = config.extend(self.handle.load_config(name)) + try: + config = config.extend(self.handle.load_config(name)) + except RuntimeError as err: + screens.append( + AlertWindow(f"Failed to use file {name}\n{err}") + ) + continue + if config is not None: self.flashcard = Flashcard(self.handle, config.get_words()) - self.push_screen(StudyScreen(self.flashcard)) + await self.push_screen(StudyScreen(self.flashcard)) + for screen in screens: + await self.push_screen(screen) - def on_delete(self, button_id: str, parent: DOMNode | None): + def on_delete(self): """ On delete button press. """ - name = bytes.fromhex(button_id.split("delete-")[1]).decode("utf-8") - self.handle.delete_set(name) - if parent is not None: - parent.remove() # type: ignore + scrollable = self.query_one("#scrollable") + checkboxes = map(lambda child: child.query_one(Checkbox), scrollable.children) + for checkbox in checkboxes: + if checkbox.value: + checkbox_id = checkbox.id + if checkbox_id is not None: + name = bytes.fromhex(checkbox_id.split("select-")[1]).decode( + "utf-8" + ) + set_id = self.handle.get_id_from_file_name(name) + if set_id is not None: + self.handle.delete_set(set_id) + checkbox.parent.remove() def on_exit_study(self): """ @@ -133,39 +149,41 @@ async def on_complete_import(self): """ On complete import button press. """ + await self.pop_screen() toml = None for import_file in self.imports: try: toml = TomlConfig(import_file) - except tomllib.TOMLDecodeError: - self.pop_screen() + except tomllib.TOMLDecodeError as err: + self.push_screen(AlertWindow(f"{import_file}: {err}")) + return + new = False if toml is not None: - self.handle.import_set( - import_file, - toml, - await scrape(toml.get_words(), toml.get_lang()), - ) + try: + new = self.handle.import_set( + import_file, + toml, + await scrape(toml.get_words(), toml.get_lang()), + ) + except IntegrityError as err: + set_id = self.handle.get_id_from_file_name(import_file) + if set_id is not None: + self.handle.delete_set(set_id) + self.push_screen(AlertWindow(f"{import_file}: {err}")) + continue hex_string = import_file.encode("utf-8").hex() scrollable = self.query_one("#scrollable") - try: - scrollable.query_one(f"#delete-{hex_string}") - except NoMatches: + if new: scrollable.mount( Horizontal( Checkbox(id=f"select-{hex_string}"), - Container(), Label(import_file), - Container(), - Button("Delete", id=f"delete-{hex_string}"), + classes="flashcard-set", ) ) - self.pop_screen() - - self.imports = {} - # pylint: disable=too-many-branches async def on_button_pressed(self, event: Button.Pressed): """ @@ -174,11 +192,12 @@ async def on_button_pressed(self, event: Button.Pressed): button_id = event.button.id if button_id is not None: if button_id == "import": + self.imports = {} self.push_screen(ImportPopup()) elif button_id == "start": - self.on_start() + await self.on_start() elif button_id.startswith("delete"): - self.on_delete(button_id, event.button.parent) + self.on_delete() elif button_id == "exit": self.exit() elif button_id == "exit_study": @@ -394,3 +413,32 @@ async def on_button_pressed(self, event): await self.grade(4) elif event.button.id == "five": await self.grade(5) + + +class AlertWindow(ModalScreen): + """ + Alert window showing an error that occurred. + """ + + def __init__(self, message, *args, **kwargs): + super().__init__(*args, **kwargs) + self.message = message + + def compose(self): + yield Vertical( + Container(), + Horizontal(Container(), Label(self.message), Container()), + Container(), + ) + + def key_enter(self): + """ + Handle enter key press. + """ + self.dismiss() + + def on_click(self): + """ + Handle mouse click. + """ + self.dismiss()