Skip to content

Commit

Permalink
Convert column of varchar into foreign key to extensible enumeratio…
Browse files Browse the repository at this point in the history
…n type table
  • Loading branch information
hunyadi committed Oct 28, 2024
1 parent 48ac6fd commit cad5858
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 16 deletions.
4 changes: 3 additions & 1 deletion pysqlsync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,9 @@ def _get_dataclass_extractors(
"Returns a tuple of callable function objects that extracts each field of a data-class."

return tuple(
self.get_field_extractor(table.columns[field.name], field.name, field.type)
self.get_field_extractor(
table.columns[field.name], field.name, typing.cast(type, field.type)
)
for field in dataclasses.fields(entity_type)
if not (skip_identity and table.columns[field.name].identity)
)
Expand Down
7 changes: 6 additions & 1 deletion pysqlsync/dialect/mysql/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
deleted,
join_or_none,
)
from pysqlsync.model.data_types import SqlEnumType, quote
from pysqlsync.model.data_types import SqlEnumType, SqlVariableCharacterType, quote
from pysqlsync.model.id_types import LocalId

from .object_types import MySQLColumn, MySQLTable
Expand All @@ -27,6 +27,11 @@ def migrate_column_stmt(
f'INSERT INTO {ref.table} ("value") VALUES {enum_values}\n'
'ON DUPLICATE KEY UPDATE "value" = "value";'
)
elif isinstance(source.data_type, SqlVariableCharacterType):
statements.append(
f'INSERT INTO {ref.table} ("value") SELECT DISTINCT {LocalId(deleted(source.name.id))} FROM {source_table.name}\n'
'ON DUPLICATE KEY UPDATE "value" = "value";'
)
statements.append(
f"UPDATE {source_table.name} data_table\n"
f'JOIN {ref.table} enum_table ON data_table.{LocalId(deleted(source.name.id))} = enum_table."value"\n'
Expand Down
16 changes: 15 additions & 1 deletion pysqlsync/dialect/postgresql/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,27 @@
deleted,
join_or_none,
)
from pysqlsync.model.data_types import quote
from pysqlsync.model.data_types import SqlIntegerType, SqlUserDefinedType, quote
from pysqlsync.model.id_types import LocalId
from pysqlsync.util.typing import override

from .object_types import sql_quoted_string


class PostgreSQLMutator(Mutator):
@override
def is_column_migrated(self, source: Column, target: Column) -> Optional[bool]:
is_migrated = super().is_column_migrated(source, target)
if is_migrated is not None:
return is_migrated

if isinstance(target.data_type, SqlIntegerType):
# PostgreSQL defines a separate type for each enumeration type
if isinstance(source.data_type, SqlUserDefinedType):
return True

return None # undecided (converts to False in a Boolean expression)

def migrate_enum_stmt(self, enum_type: EnumType, table: Table) -> Optional[str]:
enum_values = ", ".join(f"({quote(v)})" for v in enum_type.values)
return f'INSERT INTO {table.name} ("value") VALUES {enum_values} ON CONFLICT ("value") DO NOTHING;'
Expand Down
18 changes: 17 additions & 1 deletion pysqlsync/dialect/snowflake/object_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional

from pysqlsync.formation.object_types import Column, ObjectFactory, Table
from pysqlsync.model.data_types import SqlTimestampType
from pysqlsync.model.id_types import LocalId

_sql_quoted_str_table = str.maketrans(
Expand Down Expand Up @@ -43,10 +44,25 @@ def primary_key_constraint_id(self) -> LocalId:


class SnowflakeColumn(Column):
@property
def default_expr(self) -> str:
if self.default is None:
raise ValueError("default value is NULL")

if isinstance(self.data_type, SqlTimestampType):
m = re.match(
r"^'(?P<year>\d{4})-(?P<month>\d{2})-(?P<day>\d{2}) (?P<hour>\d{2}):(?P<minute>\d{2}):(?P<second>\d{2})'$",
self.default,
)
if m:
return f"TIMESTAMP {self.default}"

return self.default

@property
def data_spec(self) -> str:
nullable = " NOT NULL" if not self.nullable and not self.identity else ""
default = f" DEFAULT {self.default}" if self.default is not None else ""
default = f" DEFAULT {self.default_expr}" if self.default is not None else ""
identity = " IDENTITY" if self.identity else ""
description = f" COMMENT {self.comment}" if self.description is not None else ""
return f"{self.data_type}{nullable}{default}{identity}{description}"
Expand Down
36 changes: 25 additions & 11 deletions pysqlsync/formation/mutation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from dataclasses import dataclass
from typing import Optional

from ..model.data_types import SqlEnumType, SqlUserDefinedType, constant, quote
from ..model.data_types import (
SqlEnumType,
SqlIntegerType,
SqlVariableCharacterType,
constant,
quote,
)
from ..model.id_types import SupportsName
from .object_types import (
Catalog,
Expand Down Expand Up @@ -101,20 +107,28 @@ def migrate_column_stmt(
) -> Optional[str]:
return None

def is_column_migrated(self, source: Column, target: Column) -> bool:
def is_column_migrated(self, source: Column, target: Column) -> Optional[bool]:
"""
True if the column requires data migration, false if no migration is needed.
:param source: The source column to convert data from.
:param target: The target column to convert data to.
:returns: True/False, or None if subclass method is to decide.
"""

# no migration is needed if column data type is unchanged
if source == target or source.data_type == target.data_type:
return False

is_user_enum_source = isinstance(
source.data_type, (SqlEnumType, SqlUserDefinedType)
)
is_user_enum_target = isinstance(
target.data_type, (SqlEnumType, SqlUserDefinedType)
)
if is_user_enum_source and not is_user_enum_target:
return True
# check if target data type is a primary key type
if isinstance(target.data_type, SqlIntegerType):
if isinstance(source.data_type, SqlEnumType):
return True
# some database engines represent enumerations as `varchar`
elif isinstance(source.data_type, SqlVariableCharacterType):
return True

return False
return None # undecided (converts to False in a Boolean expression)

def mutate_column_stmt(self, source: Column, target: Column) -> Optional[str]:
if source == target:
Expand Down
2 changes: 1 addition & 1 deletion pysqlsync/model/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def constant(v: Any) -> str:
timestamp = v.astimezone(tz=datetime.timezone.utc).replace(tzinfo=None)
else:
timestamp = v
return f"TIMESTAMP {quote(timestamp.isoformat(sep=' '))}"
return quote(timestamp.isoformat(sep=" "))
elif isinstance(v, tuple):
values = ", ".join(constant(value) for value in v)
return f"({values})"
Expand Down

0 comments on commit cad5858

Please sign in to comment.