From 68ad8318eb5f584ab15e1ef97f6b294eea9caa84 Mon Sep 17 00:00:00 2001 From: Levente Hunyadi Date: Tue, 10 Sep 2024 13:45:52 +0200 Subject: [PATCH] Fix miscellaneous issues with Microsoft SQL Server --- pysqlsync/dialect/mssql/connection.py | 2 + pysqlsync/dialect/mssql/mutation.py | 6 ++- pysqlsync/dialect/mssql/object_types.py | 50 +++++++++++++------------ pysqlsync/formation/object_types.py | 4 ++ tests/test_generator.py | 22 ++++------- 5 files changed, 44 insertions(+), 40 deletions(-) diff --git a/pysqlsync/dialect/mssql/connection.py b/pysqlsync/dialect/mssql/connection.py index 3e2b594..7139b65 100644 --- a/pysqlsync/dialect/mssql/connection.py +++ b/pysqlsync/dialect/mssql/connection.py @@ -42,6 +42,8 @@ def open(self) -> BaseContext: "PWD": self.params.password, "TrustServerCertificate": "yes", } + if self.params.database is not None: + params["DATABASE"] = self.params.database conn_string = ";".join( f"{key}={value}" for key, value in params.items() if value is not None ) diff --git a/pysqlsync/dialect/mssql/mutation.py b/pysqlsync/dialect/mssql/mutation.py index c9a4075..5f49fe8 100644 --- a/pysqlsync/dialect/mssql/mutation.py +++ b/pysqlsync/dialect/mssql/mutation.py @@ -1,3 +1,4 @@ +import typing from typing import Optional from pysqlsync.formation.mutation import Mutator @@ -8,7 +9,8 @@ Table, join_or_none, ) -from pysqlsync.model.id_types import LocalId + +from .object_types import MSSQLColumn class MSSQLMutator(Mutator): @@ -37,7 +39,7 @@ def mutate_table_stmt(self, source: Table, target: Table) -> Optional[str]: if source_def == target_def: continue - name = LocalId(f"df_{source_column.name.local_id}") + name = typing.cast(MSSQLColumn, source_column).default_constraint_name() if source_def is not None: constraints.append(f"DROP CONSTRAINT {name}") if target_def is not None: diff --git a/pysqlsync/dialect/mssql/object_types.py b/pysqlsync/dialect/mssql/object_types.py index 156d4b0..9a27b25 100644 --- a/pysqlsync/dialect/mssql/object_types.py +++ b/pysqlsync/dialect/mssql/object_types.py @@ -1,3 +1,4 @@ +import typing from typing import Optional from pysqlsync.formation.object_types import ( @@ -10,14 +11,38 @@ ) from pysqlsync.model.data_types import quote from pysqlsync.model.id_types import LocalId +from pysqlsync.util.typing import override class MSSQLColumn(Column): + def default_constraint_name(self) -> LocalId: + "The name of the constraint for DEFAULT." + + return LocalId(f"df_{self.name.local_id}") + @property def data_spec(self) -> str: nullable = " NOT NULL" if not self.nullable else "" + name = self.default_constraint_name() + default = ( + f" CONSTRAINT {name} DEFAULT {self.default}" + if self.default is not None + else "" + ) identity = " IDENTITY" if self.identity else "" - return f"{self.data_type}{nullable}{identity}" + return f"{self.data_type}{nullable}{default}{identity}" + + @override + def create_stmt(self) -> str: + return f"ADD {self.column_spec}" + + @override + def drop_stmt(self) -> str: + if self.default is not None: + name = self.default_constraint_name() + return f"DROP CONSTRAINT {name}, COLUMN {self.name}" + else: + return f"DROP COLUMN {self.name}" class MSSQLTable(Table): @@ -28,39 +53,16 @@ def alter_table_stmt(self, statements: list[str]) -> str: def add_constraints_stmt(self) -> Optional[str]: statements: list[str] = [] - - constraints: list[str] = [] - for column in self.columns.values(): - if column.default is None: - continue - name = LocalId(f"df_{column.name.local_id}") - constraints.append( - f"ADD CONSTRAINT {name} DEFAULT {column.default} FOR {column.name}" - ) - if constraints: - statements.append(self.alter_table_stmt(constraints)) - if self.table_constraints: statements.append( f"ALTER TABLE {self.name} ADD\n" + ",\n".join(f"CONSTRAINT {c.spec}" for c in self.table_constraints) + ";" ) - return join_or_none(statements) def drop_constraints_stmt(self) -> Optional[str]: statements: list[str] = [] - - constraints: list[str] = [] - for column in self.columns.values(): - if column.default is None: - continue - name = LocalId(f"df_{column.name.local_id}") - constraints.append(f"DROP CONSTRAINT {name}") - if constraints: - statements.append(self.alter_table_stmt(constraints)) - if self.table_constraints: statements.append( f"ALTER TABLE {self.name} DROP\n" diff --git a/pysqlsync/formation/object_types.py b/pysqlsync/formation/object_types.py index 401651e..576c13d 100644 --- a/pysqlsync/formation/object_types.py +++ b/pysqlsync/formation/object_types.py @@ -203,9 +203,13 @@ def data_spec(self) -> str: return f"{self.data_type}{nullable}{default}{identity}" def create_stmt(self) -> str: + "Creates a column as part of an ALTER TABLE statement." + return f"ADD COLUMN {self.column_spec}" def drop_stmt(self) -> str: + "Removes a column as part of an ALTER TABLE statement." + return f"DROP COLUMN {self.name}" def soft_drop_stmt(self) -> str: diff --git a/tests/test_generator.py b/tests/test_generator.py index 405d14c..4b44138 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -130,18 +130,13 @@ def test_create_default_numeric_table(self) -> None: tables.DefaultNumericTable, 'CREATE TABLE "DefaultNumericTable" (\n' '"id" bigint NOT NULL,\n' - '"integer_8" smallint NOT NULL,\n' - '"integer_16" smallint NOT NULL,\n' - '"integer_32" integer NOT NULL,\n' - '"integer_64" bigint NOT NULL,\n' - '"integer" bigint NOT NULL,\n' + '"integer_8" smallint NOT NULL CONSTRAINT "df_integer_8" DEFAULT 127,\n' + '"integer_16" smallint NOT NULL CONSTRAINT "df_integer_16" DEFAULT 32767,\n' + '"integer_32" integer NOT NULL CONSTRAINT "df_integer_32" DEFAULT 2147483647,\n' + '"integer_64" bigint NOT NULL CONSTRAINT "df_integer_64" DEFAULT 0,\n' + '"integer" bigint NOT NULL CONSTRAINT "df_integer" DEFAULT 23,\n' 'CONSTRAINT "pk_DefaultNumericTable" PRIMARY KEY ("id")\n' - ");\n" - 'ALTER TABLE "DefaultNumericTable" ADD CONSTRAINT "df_integer_8" DEFAULT 127 FOR "integer_8";\n' - 'ALTER TABLE "DefaultNumericTable" ADD CONSTRAINT "df_integer_16" DEFAULT 32767 FOR "integer_16";\n' - 'ALTER TABLE "DefaultNumericTable" ADD CONSTRAINT "df_integer_32" DEFAULT 2147483647 FOR "integer_32";\n' - 'ALTER TABLE "DefaultNumericTable" ADD CONSTRAINT "df_integer_64" DEFAULT 0 FOR "integer_64";\n' - 'ALTER TABLE "DefaultNumericTable" ADD CONSTRAINT "df_integer" DEFAULT 23 FOR "integer";', + ");", ) def test_create_fixed_precision_float_table(self) -> None: @@ -291,10 +286,9 @@ def test_create_default_datetime_table(self) -> None: tables.DefaultDateTimeTable, 'CREATE TABLE "DefaultDateTimeTable" (\n' '"id" bigint NOT NULL,\n' - """"iso_date_time" datetime2 NOT NULL,\n""" + """"iso_date_time" datetime2 NOT NULL CONSTRAINT "df_iso_date_time" DEFAULT '1989-10-24 23:59:59',\n""" 'CONSTRAINT "pk_DefaultDateTimeTable" PRIMARY KEY ("id")\n' - ");\n" - """ALTER TABLE "DefaultDateTimeTable" ADD CONSTRAINT "df_iso_date_time" DEFAULT '1989-10-24 23:59:59' FOR "iso_date_time";""", + ");", ) def test_create_enum_table(self) -> None: