diff --git a/check.bat b/check.bat new file mode 100644 index 0000000..e361979 --- /dev/null +++ b/check.bat @@ -0,0 +1,11 @@ +@echo off +setlocal + +set python=python.exe + +%python% -m mypy pysqlsync || exit /b +%python% -m flake8 pysqlsync || exit /b +%python% -m mypy tests || exit /b +%python% -m flake8 tests || exit /b + +:quit diff --git a/pysqlsync/base.py b/pysqlsync/base.py index 1dcca11..b94200a 100644 --- a/pysqlsync/base.py +++ b/pysqlsync/base.py @@ -26,12 +26,11 @@ Union, overload, ) -from urllib.parse import quote from strong_typing.inspection import DataclassInstance, is_dataclass_type, is_type_enum from strong_typing.name import python_type_to_str -from .connection import ConnectionSSLMode +from .connection import ConnectionParameters from .formation.inspection import get_entity_types from .formation.mutation import Mutator, MutatorOptions from .formation.object_types import ( @@ -531,26 +530,6 @@ def __str__(self) -> str: return f"error executing query:\n{query}" -@dataclass -class ConnectionParameters: - "Database connection parameters that would typically be encapsulated in a connection string." - - host: Optional[str] = None - port: Optional[int] = None - username: Optional[str] = None - password: Optional[str] = None - database: Optional[str] = None - ssl: Optional[ConnectionSSLMode] = None - - def __str__(self) -> str: - host = self.host or "localhost" - port = f":{self.port}" if self.port else "" - username = f"{quote(self.username, safe='')}@" if self.username else "" - database = f"/{quote(self.database, safe='')}" if self.database else "" - ssl = f"?ssl={self.ssl}" if self.ssl else "" - return f"{username}{host}{port}{database}{ssl}" - - class BaseConnection(abc.ABC): "An active connection to a database." diff --git a/pysqlsync/connection.py b/pysqlsync/connection.py index 4f233d4..bc58478 100644 --- a/pysqlsync/connection.py +++ b/pysqlsync/connection.py @@ -9,7 +9,9 @@ import enum import ssl import sys +from dataclasses import dataclass from typing import Optional +from urllib.parse import quote if sys.version_info >= (3, 10): import truststore @@ -41,6 +43,8 @@ class ConnectionSSLMode(enum.Enum): def create_context(ssl_mode: ConnectionSSLMode) -> Optional[ssl.SSLContext]: + "Creates an SSL context to pass to a database connection object." + if ssl_mode is None or ssl_mode is ConnectionSSLMode.disable: return None elif ( @@ -74,3 +78,23 @@ def create_context(ssl_mode: ConnectionSSLMode) -> Optional[ssl.SSLContext]: return ctx else: raise ValueError(f"unsupported SSL mode: {ssl_mode}") + + +@dataclass +class ConnectionParameters: + "Database connection parameters that would typically be encapsulated in a connection string." + + host: Optional[str] = None + port: Optional[int] = None + username: Optional[str] = None + password: Optional[str] = None + database: Optional[str] = None + ssl: Optional[ConnectionSSLMode] = None + + def __str__(self) -> str: + host = self.host or "localhost" + port = f":{self.port}" if self.port else "" + username = f"{quote(self.username, safe='')}@" if self.username else "" + database = f"/{quote(self.database, safe='')}" if self.database else "" + ssl = f"?ssl={self.ssl}" if self.ssl else "" + return f"{username}{host}{port}{database}{ssl}" diff --git a/pysqlsync/factory.py b/pysqlsync/factory.py index 4f19169..e677467 100644 --- a/pysqlsync/factory.py +++ b/pysqlsync/factory.py @@ -15,13 +15,8 @@ from strong_typing.inspection import get_module_classes -from .base import ( - BaseConnection, - BaseEngine, - BaseGenerator, - ConnectionParameters, - Explorer, -) +from .base import BaseConnection, BaseEngine, BaseGenerator, Explorer +from .connection import ConnectionParameters LOGGER = logging.getLogger("pysqlsync") diff --git a/setup.cfg b/setup.cfg index 56bde3e..87860ab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,7 +29,7 @@ include_package_data = True packages = find: python_requires = >=3.9 install_requires = - certify >= 2024.8.30 + certifi >= 2024.8.30 truststore >= 0.9; python_version>="3.10" json_strong_typing >= 0.3.2 typing_extensions >= 4.8; python_version<"3.12" diff --git a/tests/example.py b/tests/example.py index 7307f9d..fd614bc 100644 --- a/tests/example.py +++ b/tests/example.py @@ -10,7 +10,8 @@ from strong_typing.auxiliary import Annotated, MaxLength -from pysqlsync.base import ConnectionParameters, GeneratorOptions +from pysqlsync.base import GeneratorOptions +from pysqlsync.connection import ConnectionParameters from pysqlsync.factory import get_dialect from pysqlsync.formation.py_to_sql import EnumMode from pysqlsync.model.id_types import LocalId diff --git a/tests/params.py b/tests/params.py index 2191433..e7af508 100644 --- a/tests/params.py +++ b/tests/params.py @@ -3,7 +3,8 @@ import os import os.path -from pysqlsync.base import BaseEngine, ConnectionParameters, ConnectionSSLMode +from pysqlsync.base import BaseEngine +from pysqlsync.connection import ConnectionParameters, ConnectionSSLMode from pysqlsync.factory import get_dialect diff --git a/tests/test_api.py b/tests/test_api.py index 94805e2..497afec 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,7 +1,7 @@ import unittest from urllib.parse import quote -from pysqlsync.base import ConnectionParameters +from pysqlsync.connection import ConnectionParameters from pysqlsync.factory import get_dialect, get_parameters