diff --git a/.github/workflows/edgetest.yml b/.github/workflows/edgetest.yml index f699c1f..a690653 100644 --- a/.github/workflows/edgetest.yml +++ b/.github/workflows/edgetest.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest name: running edgetest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: ref: develop - name: Copy files for locopy diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml index 82f07c3..a801a78 100644 --- a/.github/workflows/publish-docs.yml +++ b/.github/workflows/publish-docs.yml @@ -12,14 +12,14 @@ jobs: build-and-deploy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: # fetch all tags so `versioneer` can properly determine current version fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: - python-version: '3.8' + python-version: '3.9' - name: Install dependencies run: python -m pip install .[dev] - name: Build diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 51fbb67..d67fc8e 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -11,12 +11,12 @@ jobs: deploy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: # fetch all tags so `versioneer` can properly determine current version fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: '3.10' - name: Install dependencies diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index abf94e8..fa50759 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -10,17 +10,32 @@ on: branches: [develop, main] jobs: + lint-and-format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.9 + uses: actions/setup-python@v5 + with: + python-version: "3.9" + - name: Install dependencies + run: python -m pip install .[qa] + - name: Linting by ruff + run: ruff check + - name: Formatting by ruff + run: ruff format --check + build: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.9, '3.10', '3.11'] + python-version: [3.9, '3.10', '3.11'] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 680cefa..d9fa300 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,25 @@ repos: - - repo: https://github.com/psf/black - rev: 22.6.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.7 hooks: - - id: black - types: [file, python] - language_version: python3.9 - - repo: https://github.com/pre-commit/mirrors-isort - rev: v5.7.0 + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format + types_or: [ python, jupyter ] + # # Mypy: Optional static type checking + # # https://github.com/pre-commit/mirrors-mypy + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: v1.11.1 + # hooks: + # - id: mypy + # exclude: ^(docs|tests)\/ + # language_version: python3.9 + # args: [--namespace-packages, --explicit-package-bases, --ignore-missing-imports, --non-interactive, --install-types] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 hooks: - - id: isort + - id: trailing-whitespace + - id: debug-statements + - id: end-of-file-fixer diff --git a/CODEOWNERS b/CODEOWNERS index 53ed855..a61d6a5 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1 +1 @@ -* @fdosani @ak-gupta @jdawang @gladysteh99 @NikhilJArora +* @fdosani @ak-gupta @jdawang @gladysteh99 diff --git a/README.rst b/README.rst index c45e50a..1e73b24 100644 --- a/README.rst +++ b/README.rst @@ -13,7 +13,7 @@ A Python library to assist with ETL processing for: In addition: -- The library supports Python 3.8 to 3.10 +- The library supports Python 3.9 to 3.11 - DB Driver (Adapter) agnostic. Use your favourite driver that complies with `DB-API 2.0 `_ - It provides functionality to download and upload data to S3 buckets, and internal stages (Snowflake) diff --git a/docs/source/conf.py b/docs/source/conf.py index 6e9f150..4315309 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # SPDX-Copyright: Copyright (c) Capital One Services, LLC # SPDX-License-Identifier: Apache-2.0 # Copyright 2018 Capital One Services, LLC @@ -27,8 +26,9 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # -import sys import os +import sys + import sphinx_rtd_theme from locopy._version import __version__ @@ -151,7 +151,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, "locopy.tex", "locopy Documentation", "Faisal Dosani, Ian Robertson", "manual") + ( + master_doc, + "locopy.tex", + "locopy Documentation", + "Faisal Dosani, Ian Robertson", + "manual", + ) ] @@ -185,4 +191,9 @@ intersphinx_mapping = {"https://docs.python.org/": None} # autodoc -autodoc_default_flags = ["members", "undoc-members", "show-inheritance", "inherited-members"] +autodoc_default_flags = [ + "members", + "undoc-members", + "show-inheritance", + "inherited-members", +] diff --git a/locopy/__init__.py b/locopy/__init__.py index 091f07d..cf7c376 100644 --- a/locopy/__init__.py +++ b/locopy/__init__.py @@ -13,8 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""A Python library to assist with ETL processing.""" -from .database import Database -from .redshift import Redshift -from .s3 import S3 -from .snowflake import Snowflake +from locopy.database import Database +from locopy.redshift import Redshift +from locopy.s3 import S3 +from locopy.snowflake import Snowflake + +__all__ = ["S3", "Database", "Redshift", "Snowflake"] diff --git a/locopy/_version.py b/locopy/_version.py index 5121c9b..c8554de 100644 --- a/locopy/_version.py +++ b/locopy/_version.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.5.9" +__version__ = "0.6.0" diff --git a/locopy/database.py b/locopy/database.py index 8564331..411d095 100644 --- a/locopy/database.py +++ b/locopy/database.py @@ -14,20 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Database Module -""" +"""Database Module.""" + import time -from .errors import CredentialsError, DBError -from .logger import INFO, get_logger -from .utility import read_config_yaml +from locopy.errors import CredentialsError, DBError +from locopy.logger import INFO, get_logger +from locopy.utility import read_config_yaml logger = get_logger(__name__, INFO) -class Database(object): - """This is the base class for all DBAPI 2 database connectors which will inherit this - functionality. The ``Database`` class will manage connections and handle executing queries. +class Database: + """Base class for all DBAPI 2 database connectors which will inherit this functionality. + + The ``Database`` class will manage connections and handle executing queries. Most of the functionality should work out of the box for classes which inherit minus the abstract method for ``connect`` which may vary across databases. @@ -72,13 +73,16 @@ def __init__(self, dbapi, config_yaml=None, **kwargs): self.cursor = None if config_yaml and self.connection: - raise CredentialsError("Please provide kwargs or a YAML configuraton, not both.") + raise CredentialsError( + "Please provide kwargs or a YAML configuraton, not both." + ) if config_yaml: self.connection = read_config_yaml(config_yaml) def connect(self): - """Creates a connection to a database by setting the values of the ``conn`` and ``cursor`` - attributes. + """Create a connection to a database. + + Sets the values of the ``conn`` and ``cursor`` attributes. Raises ------ @@ -90,11 +94,12 @@ def connect(self): self.cursor = self.conn.cursor() except Exception as e: logger.error("Error connecting to the database. err: %s", e) - raise DBError("Error connecting to the database.") + raise DBError("Error connecting to the database.") from e def disconnect(self): - """Terminates the connection by closing the values of the ``conn`` and ``cursor`` - attributes. + """Terminate the connection. + + Closes the values of the ``conn`` and ``cursor`` attributes. Raises ------ @@ -108,7 +113,9 @@ def disconnect(self): self.conn.close() except Exception as e: logger.error("Error disconnecting from the database. err: %s", e) - raise DBError("There is a problem disconnecting from the database.") + raise DBError( + "There is a problem disconnecting from the database." + ) from e else: logger.info("No connection to close") @@ -153,7 +160,7 @@ def execute(self, sql, commit=True, params=(), many=False, verbose=True): self.cursor.execute(sql, params) except Exception as e: logger.error("Error running SQL query. err: %s", e) - raise DBError("Error running SQL query.") + raise DBError("Error running SQL query.") from e if commit: self.conn.commit() elapsed = time.time() - start_time @@ -167,8 +174,9 @@ def execute(self, sql, commit=True, params=(), many=False, verbose=True): raise DBError("Cannot execute SQL on a closed connection.") def column_names(self): - """Pull column names out of the cursor description. Depending on the - DBAPI, it could return column names as bytes: ``b'column_name'`` + """Pull column names out of the cursor description. + + Depending on the DBAPI, it could return column names as bytes: ``b'column_name'``. Returns ------- @@ -177,12 +185,13 @@ def column_names(self): """ try: return [column[0].decode().lower() for column in self.cursor.description] - except: + except Exception: return [column[0].lower() for column in self.cursor.description] def to_dataframe(self, size=None): - """Return a dataframe of the last query results. This imports Pandas - in here, so that it's not needed for other use cases. This is just a + """Return a dataframe of the last query results. + + This imports Pandas in here, so that it's not needed for other use cases. This is just a convenience method. Parameters @@ -214,7 +223,7 @@ def to_dataframe(self, size=None): return pandas.DataFrame(fetched, columns=columns) def to_dict(self): - """Generate dictionaries of rows + """Generate dictionaries of rows. Yields ------ @@ -226,7 +235,7 @@ def to_dict(self): yield dict(zip(columns, row)) def _is_connected(self): - """Checks the connection and cursor class arrtribues are initalized. + """Check the connection and cursor class arrtributes are initalized. Returns ------- @@ -235,16 +244,18 @@ def _is_connected(self): """ try: return self.conn is not None and self.cursor is not None - except: + except Exception: return False def __enter__(self): + """Open the connection.""" logger.info("Connecting...") self.connect() logger.info("Connection established.") return self def __exit__(self, exc_type, exc, exc_tb): + """Close the connection.""" logger.info("Closing connection...") self.disconnect() logger.info("Connection closed.") diff --git a/locopy/errors.py b/locopy/errors.py index ed1f704..692d352 100644 --- a/locopy/errors.py +++ b/locopy/errors.py @@ -13,81 +13,56 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Custom errors for locopy.""" class LocopyError(Exception): - """ - Baseclass for all Locopy errors. - """ + """Baseclass for all Locopy errors.""" class CompressionError(LocopyError): - """ - Raised when there is an error compressing a file. - """ + """Raised when there is an error compressing a file.""" class LocopySplitError(LocopyError): - """ - Raised when there is an error splitting a file. - """ + """Raised when there is an error splitting a file.""" class LocopyIgnoreHeaderError(LocopyError): - """ - Raised when Multiple IGNOREHEADERS are found in copy options. - """ + """Raised when Multiple IGNOREHEADERS are found in copy options.""" class LocopyConcatError(LocopyError): - """ - Raised when there is an error concatenating files. - """ + """Raised when there is an error concatenating files.""" class DBError(Exception): - """ - Base class for all Database errors. - """ + """Base class for all Database errors.""" class CredentialsError(DBError): - """ - Raised when the users credentials are not provided. - """ + """Raised when the users credentials are not provided.""" class S3Error(Exception): - """ - Base class for all S3 errors. - """ + """Base class for all S3 errors.""" class S3CredentialsError(S3Error): - """ - Raised when there is an error with AWS credentials. - """ + """Raised when there is an error with AWS credentials.""" class S3InitializationError(S3Error): - """ - Raised when there is an error initializing S3 client. - """ + """Raised when there is an error initializing S3 client.""" class S3UploadError(S3Error): - """ - Raised when there is an upload error to S3. - """ + """Raised when there is an upload error to S3.""" class S3DownloadError(S3Error): - """ - Raised when there is an download error to S3. - """ + """Raised when there is an download error to S3.""" class S3DeletionError(S3Error): - """ - Raised when there is an deletion error on S3. - """ + """Raised when there is an deletion error on S3.""" diff --git a/locopy/logger.py b/locopy/logger.py index 2614124..c07beb4 100644 --- a/locopy/logger.py +++ b/locopy/logger.py @@ -14,35 +14,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Logging Module -Module which setsup the basic logging infrustrcuture for the application +"""Logging Module. + +Module which sets up the basic logging infrustrcuture for the application. """ + import logging import sys # logger formating BRIEF_FORMAT = "%(levelname)s %(asctime)s - %(name)s: %(message)s" VERBOSE_FORMAT = ( - "%(levelname)s|%(asctime)s|%(name)s|%(filename)s|" "%(funcName)s|%(lineno)d: %(message)s" + "%(levelname)s|%(asctime)s|%(name)s|%(filename)s|" + "%(funcName)s|%(lineno)d: %(message)s" ) FORMAT_TO_USE = VERBOSE_FORMAT # logger levels DEBUG = logging.DEBUG INFO = logging.INFO -WARN = logging.WARN +WARN = logging.WARNING ERROR = logging.ERROR CRITICAL = logging.CRITICAL def get_logger(name=None, log_level=logging.DEBUG): - """Sets the basic logging features for the application + """Set the basic logging features for the application. + Parameters ---------- name : str, optional The name of the logger. Defaults to ``None`` log_level : int, optional The logging level. Defaults to ``logging.INFO`` + Returns ------- logging.Logger diff --git a/locopy/redshift.py b/locopy/redshift.py index eeabead..84527f4 100644 --- a/locopy/redshift.py +++ b/locopy/redshift.py @@ -14,26 +14,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Redshift Module +"""Redshift Module. + Module to wrap a database adapter into a Redshift class which can be used to connect to Redshift, and run arbitrary code. """ + import os from pathlib import Path -from .database import Database -from .errors import DBError, S3CredentialsError -from .logger import INFO, get_logger -from .s3 import S3 -from .utility import (compress_file_list, concatenate_files, find_column_type, - get_ignoreheader_number, split_file, write_file) +from locopy.database import Database +from locopy.errors import DBError, S3CredentialsError +from locopy.logger import INFO, get_logger +from locopy.s3 import S3 +from locopy.utility import ( + compress_file_list, + concatenate_files, + find_column_type, + get_ignoreheader_number, + split_file, + write_file, +) logger = get_logger(__name__, INFO) def add_default_copy_options(copy_options=None): - """Adds in default options for the ``COPY`` job, unless those specific - options have been provided in the request. + """Add in default options for the ``COPY`` job. + + Unless those specific options have been provided in the request. Parameters ---------- @@ -58,8 +67,9 @@ def add_default_copy_options(copy_options=None): def combine_copy_options(copy_options): - """Returns the ``copy_options`` attribute with spaces in between and as - a string. + """Return the ``copy_options`` attribute with spaces in between. + + Converts to a string. Parameters ---------- @@ -76,8 +86,9 @@ def combine_copy_options(copy_options): class Redshift(S3, Database): - """Locopy class which manages connections to Redshift. Inherits ``Database`` and implements the - specific ``COPY`` and ``UNLOAD`` functionality. + """Locopy class which manages connections to Redshift. + + Inherits ``Database`` and implements the specific ``COPY`` and ``UNLOAD`` functionality. If any of host, port, dbname, user and password are not provided, a config_yaml file must be provided with those parameters in it. Please note ssl is always enforced when connecting. @@ -159,8 +170,9 @@ def __init__( Database.__init__(self, dbapi, config_yaml, **kwargs) def connect(self): - """Creates a connection to the Redshift cluster by - setting the values of the ``conn`` and ``cursor`` attributes. + """Create a connection to the Redshift cluster. + + Sets the values of the ``conn`` and ``cursor`` attributes. Raises ------ @@ -171,11 +183,10 @@ def connect(self): self.connection["sslmode"] = "require" elif self.dbapi.__name__ == "pg8000": self.connection["ssl_context"] = True - super(Redshift, self).connect() + super().connect() def copy(self, table_name, s3path, delim="|", copy_options=None): - """Executes the COPY command to load files from S3 into - a Redshift table. + """Execute the COPY command to load files from S3 into a Redshift table. Parameters ---------- @@ -200,10 +211,10 @@ def copy(self, table_name, s3path, delim="|", copy_options=None): """ if not self._is_connected(): raise DBError("No Redshift connection object is present.") - if copy_options and "PARQUET" not in copy_options or copy_options is None: + if (copy_options and "PARQUET" not in copy_options) or copy_options is None: copy_options = add_default_copy_options(copy_options) if delim: - copy_options = [f"DELIMITER '{delim}'"] + copy_options + copy_options = [f"DELIMITER '{delim}'", *copy_options] copy_options_text = combine_copy_options(copy_options) base_copy_string = "COPY {0} FROM '{1}' " "CREDENTIALS '{2}' " "{3};" try: @@ -214,7 +225,7 @@ def copy(self, table_name, s3path, delim="|", copy_options=None): except Exception as e: logger.error("Error running COPY on Redshift. err: %s", e) - raise DBError("Error running COPY on Redshift.") + raise DBError("Error running COPY on Redshift.") from e def load_and_copy( self, @@ -228,8 +239,9 @@ def load_and_copy( compress=True, s3_folder=None, ): - """Loads a file to S3, then copies into Redshift. Has options to - split a single file into multiple files, compress using gzip, and + r"""Load a file to S3, then copies into Redshift. + + Has options to split a single file into multiple files, compress using gzip, and upload to an S3 bucket with folders within the bucket. Notes @@ -341,8 +353,9 @@ def unload_and_copy( parallel_off=False, unload_options=None, ): - """``UNLOAD`` data from Redshift, with options to write to a flat file, - and store on S3. + """Unload data from Redshift. + + With options to write to a flat file and store on S3. Parameters ---------- @@ -390,27 +403,27 @@ def unload_and_copy( # data = [] s3path = self._generate_unload_path(s3_bucket, s3_folder) - ## configure unload options + # configure unload options if unload_options is None: unload_options = [] if delim: - unload_options.append("DELIMITER '{0}'".format(delim)) + unload_options.append(f"DELIMITER '{delim}'") if parallel_off: unload_options.append("PARALLEL OFF") - ## run unload + # run unload self.unload(query=query, s3path=s3path, unload_options=unload_options) - ## parse unloaded files + # parse unloaded files s3_download_list = self._unload_generated_files() if s3_download_list is None: logger.error("No files generated from unload") - raise Exception("No files generated from unload") + raise DBError("No files generated from unload") columns = self._get_column_names(query) if columns is None: logger.error("Unable to retrieve column names from exported data") - raise Exception("Unable to retrieve column names from exported data.") + raise DBError("Unable to retrieve column names from exported data.") # download files locally with same name local_list = self.download_list_from_s3(s3_download_list, raw_unload_path) @@ -423,8 +436,7 @@ def unload_and_copy( self.delete_list_from_s3(s3_download_list) def unload(self, query, s3path, unload_options=None): - """Executes the UNLOAD command to export a query from - Redshift to S3. + """Execute the UNLOAD command to export a query from Redshift to S3. Parameters ---------- @@ -462,10 +474,10 @@ def unload(self, query, s3path, unload_options=None): self.execute(sql, commit=True) except Exception as e: logger.error("Error running UNLOAD on redshift. err: %s", e) - raise DBError("Error running UNLOAD on redshift.") + raise DBError("Error running UNLOAD on redshift.") from e def _get_column_names(self, query): - """Gets a list of column names from the supplied query. + """Get a list of column names from the supplied query. Parameters ---------- @@ -477,22 +489,21 @@ def _get_column_names(self, query): list List of column names. Returns None if no columns were retrieved. """ - try: logger.info("Retrieving column names") - sql = "SELECT * FROM ({}) WHERE 1 = 0".format(query) + sql = f"SELECT * FROM ({query}) WHERE 1 = 0" self.execute(sql) - results = [desc for desc in self.cursor.description] + results = list(self.cursor.description) if len(results) > 0: return [result[0].strip() for result in results] else: return None - except Exception as e: + except Exception: logger.error("Error retrieving column names") raise def _unload_generated_files(self): - """Gets a list of files generated by the unload process + """Get a list of files generated by the unload process. Returns ------- @@ -511,7 +522,7 @@ def _unload_generated_files(self): return [result[0].strip() for result in results] else: return None - except Exception as e: + except Exception: logger.error("Error retrieving unloads generated files") raise @@ -556,7 +567,6 @@ def insert_dataframe_to_table( """ - import pandas as pd if columns: @@ -585,9 +595,7 @@ def insert_dataframe_to_table( + ")" ) column_sql = "(" + ",".join(list(metadata.keys())) + ")" - create_query = "CREATE TABLE {table_name} {create_join}".format( - table_name=table_name, create_join=create_join - ) + create_query = f"CREATE TABLE {table_name} {create_join}" self.execute(create_query) logger.info("New table has been created") @@ -611,9 +619,7 @@ def insert_dataframe_to_table( to_insert.append(none_row) string_join = ", ".join(to_insert) insert_query = ( - """INSERT INTO {table_name} {columns} VALUES {values}""".format( - table_name=table_name, columns=column_sql, values=string_join - ) + f"""INSERT INTO {table_name} {column_sql} VALUES {string_join}""" ) self.execute(insert_query, verbose=verbose) logger.info("Table insertion has completed") diff --git a/locopy/s3.py b/locopy/s3.py index 348cf4b..2229118 100644 --- a/locopy/s3.py +++ b/locopy/s3.py @@ -14,17 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""S3 Module +"""S3 Module. + Module to wrap the boto3 api usage and provide functionality to manage -multipart upload to S3 buckets +multipart upload to S3 buckets. """ + import os from boto3 import Session from boto3.s3.transfer import TransferConfig from botocore.client import Config -from .errors import ( +from locopy.errors import ( S3CredentialsError, S3DeletionError, S3DownloadError, @@ -32,16 +34,16 @@ S3InitializationError, S3UploadError, ) -from .logger import INFO, get_logger -from .utility import ProgressPercentage +from locopy.logger import INFO, get_logger +from locopy.utility import ProgressPercentage logger = get_logger(__name__, INFO) -class S3(object): - """ - S3 wrapper class which utilizes the boto3 library to push files to an S3 - bucket. +class S3: + """S3 wrapper class. + + Utilizes the boto3 library to push files to an S3 bucket. Parameters ---------- @@ -85,7 +87,6 @@ class S3(object): """ def __init__(self, profile=None, kms_key=None, **kwargs): - self.profile = profile self.kms_key = kms_key self.session = None @@ -100,7 +101,7 @@ def _set_session(self): logger.info("Initialized AWS session.") except Exception as e: logger.error("Error initializing AWS Session, err: %s", e) - raise S3Error("Error initializing AWS Session.") + raise S3Error("Error initializing AWS Session.") from e credentials = self.session.get_credentials() if credentials is None: raise S3CredentialsError("Credentials could not be set.") @@ -111,11 +112,12 @@ def _set_client(self): logger.info("Successfully initialized S3 client.") except Exception as e: logger.error("Error initializing S3 Client, err: %s", e) - raise S3InitializationError("Error initializing S3 Client.") + raise S3InitializationError("Error initializing S3 Client.") from e def _credentials_string(self): - """Returns a credentials string for the Redshift COPY or UNLOAD command, - containing credentials from the current session. + """Return a credentials string for the Redshift COPY or UNLOAD command. + + Containing credentials from the current session. Returns ------- @@ -131,7 +133,7 @@ def _credentials_string(self): return temp.format(creds.access_key, creds.secret_key) def _generate_s3_path(self, bucket, key): - """Will return the S3 file URL in the format S3://bucket/key + """Will return the S3 file URL in the format S3://bucket/key. Parameters ---------- @@ -146,11 +148,13 @@ def _generate_s3_path(self, bucket, key): str string of the S3 file URL in the format S3://bucket/key """ - return "s3://{0}/{1}".format(bucket, key) + return f"s3://{bucket}/{key}" def _generate_unload_path(self, bucket, folder): - """Will return the S3 file URL in the format s3://bucket/folder if a - valid (not None) folder is provided. Otherwise, returns s3://bucket + """Return the S3 file URL. + + If a valid (not None) folder is provided, returnsin the format s3://bucket/folder. + Otherwise, returns s3://bucket. Parameters ---------- @@ -168,9 +172,9 @@ def _generate_unload_path(self, bucket, folder): If folder is None, returns format s3://bucket """ if folder: - s3_path = "s3://{0}/{1}".format(bucket, folder) + s3_path = f"s3://{bucket}/{folder}" else: - s3_path = "s3://{0}".format(bucket) + s3_path = f"s3://{bucket}" return s3_path def upload_to_s3(self, local, bucket, key): @@ -204,13 +208,19 @@ def upload_to_s3(self, local, bucket, key): extra_args["SSEKMSKeyId"] = self.kms_key logger.info("Using KMS Keys for encryption") - logger.info("Uploading file to S3 bucket: %s", self._generate_s3_path(bucket, key)) + logger.info( + "Uploading file to S3 bucket: %s", self._generate_s3_path(bucket, key) + ) self.s3.upload_file( - local, bucket, key, ExtraArgs=extra_args, Callback=ProgressPercentage(local) + local, + bucket, + key, + ExtraArgs=extra_args, + Callback=ProgressPercentage(local), ) except Exception as e: logger.error("Error uploading to S3. err: %s", e) - raise S3UploadError("Error uploading to S3.") + raise S3UploadError("Error uploading to S3.") from e def upload_list_to_s3(self, local_list, bucket, folder=None): """ @@ -275,13 +285,14 @@ def download_from_s3(self, bucket, key, local): """ try: logger.info( - "Downloading file from S3 bucket: %s", self._generate_s3_path(bucket, key), + "Downloading file from S3 bucket: %s", + self._generate_s3_path(bucket, key), ) config = TransferConfig(max_concurrency=5) self.s3.download_file(bucket, key, local, Config=config) except Exception as e: logger.error("Error downloading from S3. err: %s", e) - raise S3DownloadError("Error downloading from S3.") + raise S3DownloadError("Error downloading from S3.") from e def download_list_from_s3(self, s3_list, local_path=None): """ @@ -330,11 +341,13 @@ def delete_from_s3(self, bucket, key): If there is a issue deleting from the S3 bucket """ try: - logger.info("Deleting file from S3 bucket: %s", self._generate_s3_path(bucket, key)) + logger.info( + "Deleting file from S3 bucket: %s", self._generate_s3_path(bucket, key) + ) self.s3.delete_object(Bucket=bucket, Key=key) except Exception as e: logger.error("Error deleting from S3. err: %s", e) - raise S3DeletionError("Error deleting from S3.") + raise S3DeletionError("Error deleting from S3.") from e def delete_list_from_s3(self, s3_list): """ @@ -351,9 +364,7 @@ def delete_list_from_s3(self, s3_list): self.delete_from_s3(s3_bucket, s3_key) def parse_s3_url(self, s3_url): - """ - Parse a string of the s3 url to extract the bucket and key. - scheme or not. + """Extract the bucket and key from a s3 url. Parameters ---------- diff --git a/locopy/snowflake.py b/locopy/snowflake.py index 60d3c5e..81fc842 100644 --- a/locopy/snowflake.py +++ b/locopy/snowflake.py @@ -14,18 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Snowflake Module +"""Snowflake Module. + Module to wrap a database adapter into a Snowflake class which can be used to connect to Snowflake, and run arbitrary code. """ + import os from pathlib import PurePath -from .database import Database -from .errors import DBError, S3CredentialsError -from .logger import INFO, get_logger -from .s3 import S3 -from .utility import find_column_type +from locopy.database import Database +from locopy.errors import DBError, S3CredentialsError +from locopy.logger import INFO, get_logger +from locopy.s3 import S3 +from locopy.utility import find_column_type logger = get_logger(__name__, INFO) @@ -85,8 +87,9 @@ def combine_options(options=None): - """Returns the ``copy_options`` or ``format_options`` attribute with spaces in between and as - a string. If options is ``None`` then return an empty string. + """Return the ``copy_options`` or ``format_options`` attribute. + + With spaces in between and as a string. If options is ``None`` then return an empty string. Parameters ---------- @@ -103,8 +106,9 @@ def combine_options(options=None): class Snowflake(S3, Database): - """Locopy class which manages connections to Snowflake. Inherits ``Database`` and implements - the specific ``COPY INTO`` functionality. + """Locopy class which manages connections to Snowflake. Inherits ``Database``. + + Implements the specific ``COPY INTO`` functionality. Parameters ---------- @@ -183,22 +187,23 @@ def __init__( Database.__init__(self, dbapi, config_yaml, **kwargs) def connect(self): - """Creates a connection to the Snowflake cluster by - setting the values of the ``conn`` and ``cursor`` attributes. + """Create a connection to the Snowflake cluster. + + Setg the values of the ``conn`` and ``cursor`` attributes. Raises ------ DBError If there is a problem establishing a connection to Snowflake. """ - super(Snowflake, self).connect() + super().connect() if self.connection.get("warehouse") is not None: - self.execute("USE WAREHOUSE {0}".format(self.connection["warehouse"])) + self.execute("USE WAREHOUSE {}".format(self.connection["warehouse"])) if self.connection.get("database") is not None: - self.execute("USE DATABASE {0}".format(self.connection["database"])) + self.execute("USE DATABASE {}".format(self.connection["database"])) if self.connection.get("schema") is not None: - self.execute("USE SCHEMA {0}".format(self.connection["schema"])) + self.execute("USE SCHEMA {}".format(self.connection["schema"])) def upload_to_internal( self, local, stage, parallel=4, auto_compress=True, overwrite=True @@ -231,9 +236,7 @@ def upload_to_internal( """ local_uri = PurePath(local).as_posix() self.execute( - "PUT 'file://{0}' {1} PARALLEL={2} AUTO_COMPRESS={3} OVERWRITE={4}".format( - local_uri, stage, parallel, auto_compress, overwrite - ) + f"PUT 'file://{local_uri}' {stage} PARALLEL={parallel} AUTO_COMPRESS={auto_compress} OVERWRITE={overwrite}" ) def download_from_internal(self, stage, local=None, parallel=10): @@ -255,16 +258,15 @@ def download_from_internal(self, stage, local=None, parallel=10): if local is None: local = os.getcwd() local_uri = PurePath(local).as_posix() - self.execute( - "GET {0} 'file://{1}' PARALLEL={2}".format(stage, local_uri, parallel) - ) + self.execute(f"GET {stage} 'file://{local_uri}' PARALLEL={parallel}") def copy( self, table_name, stage, file_type="csv", format_options=None, copy_options=None ): - """Executes the ``COPY INTO `` command to load CSV files from a stage into - a Snowflake table. If ``file_type == csv`` and ``format_options == None``, ``format_options`` - will default to: ``["FIELD_DELIMITER='|'", "SKIP_HEADER=0"]`` + """Load files from a stage into a Snowflake table. + + Execute the ``COPY INTO
`` command to If ``file_type == csv`` and ``format_options == None``, ``format_options`` + will default to: ``["FIELD_DELIMITER='|'", "SKIP_HEADER=0"]``. Parameters ---------- @@ -296,9 +298,7 @@ def copy( if file_type not in COPY_FORMAT_OPTIONS: raise ValueError( - "Invalid file_type. Must be one of {0}".format( - list(COPY_FORMAT_OPTIONS.keys()) - ) + f"Invalid file_type. Must be one of {list(COPY_FORMAT_OPTIONS.keys())}" ) if format_options is None and file_type == "csv": @@ -317,7 +317,7 @@ def copy( except Exception as e: logger.error("Error running COPY on Snowflake. err: %s", e) - raise DBError("Error running COPY on Snowflake.") + raise DBError("Error running COPY on Snowflake.") from e def unload( self, @@ -328,9 +328,12 @@ def unload( header=False, copy_options=None, ): - """Executes the ``COPY INTO `` command to export a query/table from - Snowflake to a stage. If ``file_type == csv`` and ``format_options == None``, ``format_options`` - will default to: ``["FIELD_DELIMITER='|'"]`` + """Export a query/table from Snowflake to a stage. + + Execute the ``COPY INTO `` command. + + If ``file_type == csv`` and ``format_options == None``, ``format_options`` + will default to: ``["FIELD_DELIMITER='|'"]``. Parameters ---------- @@ -364,9 +367,7 @@ def unload( if file_type not in COPY_FORMAT_OPTIONS: raise ValueError( - "Invalid file_type. Must be one of {0}".format( - list(UNLOAD_FORMAT_OPTIONS.keys()) - ) + f"Invalid file_type. Must be one of {list(UNLOAD_FORMAT_OPTIONS.keys())}" ) if format_options is None and file_type == "csv": @@ -390,13 +391,14 @@ def unload( self.execute(sql, commit=True) except Exception as e: logger.error("Error running UNLOAD on Snowflake. err: %s", e) - raise DBError("Error running UNLOAD on Snowflake.") + raise DBError("Error running UNLOAD on Snowflake.") from e def insert_dataframe_to_table( self, dataframe, table_name, columns=None, create=False, metadata=None ): - """ - Insert a Pandas dataframe to an existing table or a new table. In newer versions of the + """Insert a Pandas dataframe to an existing table or a new table. + + In newer versions of the python snowflake connector (v2.1.2+) users can call the ``write_pandas`` method from the cursor directly, ``insert_dataframe_to_table`` is a custom implementation and does not use ``write_pandas``. Instead of using ``COPY INTO`` the method builds a list of tuples to @@ -421,7 +423,6 @@ def insert_dataframe_to_table( metadata: dictionary, optional If metadata==None, it will be generated based on data """ - import pandas as pd if columns: @@ -434,7 +435,7 @@ def insert_dataframe_to_table( # create a list of tuples for insert to_insert = [] for row in dataframe.itertuples(index=False): - none_row = tuple([None if pd.isnull(val) else str(val) for val in row]) + none_row = tuple(None if pd.isnull(val) else str(val) for val in row) to_insert.append(none_row) if not create and metadata: @@ -457,22 +458,20 @@ def insert_dataframe_to_table( + ")" ) column_sql = "(" + ",".join(list(metadata.keys())) + ")" - create_query = "CREATE TABLE {table_name} {create_join}".format( - table_name=table_name, create_join=create_join - ) + create_query = f"CREATE TABLE {table_name} {create_join}" self.execute(create_query) logger.info("New table has been created") - insert_query = """INSERT INTO {table_name} {columns} VALUES {values}""".format( - table_name=table_name, columns=column_sql, values=string_join - ) + insert_query = f"""INSERT INTO {table_name} {column_sql} VALUES {string_join}""" logger.info("Inserting records...") self.execute(insert_query, params=to_insert, many=True) logger.info("Table insertion has completed") def to_dataframe(self, size=None): - """Return a dataframe of the last query results. This is just a convenience method. This + """Return a dataframe of the last query results. + + This is just a convenience method. This method overrides the base classes implementation in favour for the snowflake connectors built-in ``fetch_pandas_all`` when ``size==None``. If ``size != None`` then we will continue to use the existing functionality where we iterate through the cursor and build the @@ -492,4 +491,4 @@ def to_dataframe(self, size=None): if size is None and self.cursor._query_result_format == "arrow": return self.cursor.fetch_pandas_all() else: - return super(Snowflake, self).to_dataframe(size) + return super().to_dataframe(size) diff --git a/locopy/utility.py b/locopy/utility.py index a903b41..67e8253 100644 --- a/locopy/utility.py +++ b/locopy/utility.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utility Module -Module which utility functions for use within the application +"""Utility Module. + +Module which utility functions for use within the application. """ + import gzip import os import shutil @@ -27,9 +29,14 @@ import yaml -from .errors import (CompressionError, CredentialsError, LocopyConcatError, - LocopyIgnoreHeaderError, LocopySplitError) -from .logger import INFO, get_logger +from locopy.errors import ( + CompressionError, + CredentialsError, + LocopyConcatError, + LocopyIgnoreHeaderError, + LocopySplitError, +) +from locopy.logger import INFO, get_logger logger = get_logger(__name__, INFO) @@ -63,7 +70,7 @@ def write_file(data, delimiter, filepath, mode="w"): def compress_file(input_file, output_file): - """Compresses a file (gzip) + """Compresses a file (gzip). Parameters ---------- @@ -73,17 +80,16 @@ def compress_file(input_file, output_file): Path to write the compressed file """ try: - with open(input_file, "rb") as f_in: - with gzip.open(output_file, "wb") as f_out: - logger.info("compressing (gzip): %s to %s", input_file, output_file) - shutil.copyfileobj(f_in, f_out) + with open(input_file, "rb") as f_in, gzip.open(output_file, "wb") as f_out: + logger.info("compressing (gzip): %s to %s", input_file, output_file) + shutil.copyfileobj(f_in, f_out) except Exception as e: logger.error("Error compressing the file. err: %s", e) - raise CompressionError("Error compressing the file.") + raise CompressionError("Error compressing the file.") from e def compress_file_list(file_list): - """Compresses a list of files (gzip) and clean up the old files + """Compresses a list of files (gzip) and clean up the old files. Parameters ---------- @@ -97,7 +103,7 @@ def compress_file_list(file_list): gz appended) """ for i, f in enumerate(file_list): - gz = "{0}.gz".format(f) + gz = f"{f}.gz" compress_file(f, gz) file_list[i] = gz os.remove(f) # cleanup old files @@ -149,7 +155,7 @@ def split_file(input_file, output_file, splits=1, ignore_header=0): cpool = cycle(pool) logger.info("splitting file: %s into %s files", input_file, splits) # open output file handlers - files = [open("{0}.{1}".format(output_file, x), "wb") for x in pool] + files = [open(f"{output_file}.{x}", "wb") for x in pool] # noqa: SIM115 # open input file and send line to different handler with open(input_file, "rb") as f_in: # if we have a value in ignore_header then skip those many lines to start @@ -168,7 +174,7 @@ def split_file(input_file, output_file, splits=1, ignore_header=0): for x in pool: files[x].close() os.remove(files[x].name) - raise LocopySplitError("Error splitting the file.") + raise LocopySplitError("Error splitting the file.") from e def concatenate_files(input_list, output_file, remove=True): @@ -202,13 +208,15 @@ def concatenate_files(input_list, output_file, remove=True): os.remove(f) except Exception as e: logger.error("Error concateneating files. err: %s", e) - raise LocopyConcatError("Error concateneating files.") + raise LocopyConcatError("Error concateneating files.") from e def read_config_yaml(config_yaml): - """ - Reads a configuration YAML file to populate the database - connection attributes, and validate required ones. Example:: + """Read a configuration YAML file. + + Populate the database connection attributes, and validate required ones. + + Example:: host: my.redshift.cluster.com port: 5439 @@ -240,7 +248,7 @@ def read_config_yaml(config_yaml): locopy_yaml = yaml.safe_load(config_yaml) except Exception as e: logger.error("Error reading yaml. err: %s", e) - raise CredentialsError("Error reading yaml.") + raise CredentialsError("Error reading yaml.") from e return locopy_yaml @@ -275,7 +283,6 @@ def find_column_type(dataframe, warehouse_type: str): A dictionary of columns with their data type """ import re - from datetime import date, datetime import pandas as pd @@ -313,7 +320,9 @@ def validate_float_object(column): data = dataframe[column].dropna().reset_index(drop=True) if data.size == 0: column_type.append("varchar") - elif (data.dtype in ["datetime64[ns]", "M8[ns]"]) or (re.match("(datetime64\[ns\,\W)([a-zA-Z]+)(\])",str(data.dtype))): + elif (data.dtype in ["datetime64[ns]", "M8[ns]"]) or ( + re.match(r"(datetime64\[ns\,\W)([a-zA-Z]+)(\])", str(data.dtype)) + ): column_type.append("timestamp") elif str(data.dtype).lower().startswith("bool"): column_type.append("boolean") @@ -333,20 +342,22 @@ def validate_float_object(column): return OrderedDict(zip(list(dataframe.columns), column_type)) -class ProgressPercentage(object): - """ - ProgressPercentage class is used by the S3Transfer upload_file callback +class ProgressPercentage: + """ProgressPercentage class is used by the S3Transfer upload_file callback. + Please see the following url for more information: - http://boto3.readthedocs.org/en/latest/reference/customizations/s3.html#ref-s3transfer-usage + http://boto3.readthedocs.org/en/latest/reference/customizations/s3.html#ref-s3transfer-usage. """ def __init__(self, filename): - """ - Initiate the ProgressPercentage class, using the base information which - makes up a pipeline - Args: - filename (str): A name of the file which we will monitor the - progress of + """Initiate the ProgressPercentage class. + + Using the base information which makes up a pipeline + + Parameters + ---------- + filename (str): A name of the file which we will monitor the + progress of. """ self._filename = filename self._size = float(os.path.getsize(filename)) @@ -354,13 +365,15 @@ def __init__(self, filename): self._lock = threading.Lock() def __call__(self, bytes_amount): - # To simplify we'll assume this is hooked up - # to a single filename. + """Call as a function. + + To simplify we'll assume this is hooked up to a single filename. + """ with self._lock: self._seen_so_far += bytes_amount percentage = (self._seen_so_far / self._size) * 100 sys.stdout.write( - "\rTransfering [{0}] {1:.2f}%".format( + "\rTransfering [{}] {:.2f}%".format( "#" * int(percentage / 10), percentage ) ) @@ -368,9 +381,9 @@ def __call__(self, bytes_amount): def get_ignoreheader_number(options): - """ - Return the ``number_rows`` from ``IGNOREHEADER [ AS ] number_rows`` This doesn't not validate - that the ``AS`` is valid. + """Return the ``number_rows`` from ``IGNOREHEADER [ AS ] number_rows``. + + This doesn't validate that the ``AS`` is valid. Parameters ---------- diff --git a/pyproject.toml b/pyproject.toml index ef3555f..ffefeeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,9 +6,9 @@ authors = [ { name="Faisal Dosani", email="faisal.dosani@capitalone.com" }, ] license = {text = "Apache Software License"} -dependencies = ["boto3<=1.34.126,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.2,>=0.25.2", "numpy<=1.26.4,>=1.22.0"] +dependencies = ["boto3<=1.34.157,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.2,>=0.25.2", "numpy<=2.0.1,>=1.22.0"] -requires-python = ">=3.8.0" +requires-python = ">=3.9.0" classifiers = [ "Intended Audience :: Developers", "Natural Language :: English", @@ -16,7 +16,6 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -41,17 +40,68 @@ pg8000 = ["pg8000>=1.13.1"] snowflake = ["snowflake-connector-python[pandas]>=2.1.2"] docs = ["sphinx", "sphinx_rtd_theme"] tests = ["hypothesis", "pytest", "pytest-cov"] -qa = ["pre-commit", "black", "isort"] +qa = ["pre-commit", "ruff==0.5.7"] build = ["build", "twine", "wheel"] edgetest = ["edgetest", "edgetest-conda"] dev = ["locopy[tests]", "locopy[docs]", "locopy[qa]", "locopy[build]"] -[isort] -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true -line_length = 88 +# Linters, formatters and type checkers +[tool.ruff] +extend-include = ["*.ipynb"] +target-version = "py39" +src = ["src"] + + +[tool.ruff.lint] +preview = true +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "D", # pydocstyle + "I", # isort + "UP", # pyupgrade + "B", # flake8-bugbear + # "A", # flake8-builtins + "C4", # flake8-comprehensions + #"C901", # mccabe complexity + # "G", # flake8-logging-format + "T20", # flake8-print + "TID252", # flake8-tidy-imports ban relative imports + # "ARG", # flake8-unused-arguments + "SIM", # flake8-simplify + "NPY", # numpy rules + "LOG", # flake8-logging + "RUF", # Ruff errors +] + + +ignore = [ + "E111", # Check indentation level. Using formatter instead. + "E114", # Check indentation level. Using formatter instead. + "E117", # Check indentation level. Using formatter instead. + "E203", # Check whitespace. Using formatter instead. + "E501", # Line too long. Using formatter instead. + "D206", # Docstring indentation. Using formatter instead. + "D300", # Use triple single quotes. Using formatter instead. + "SIM108", # Use ternary operator instead of if-else blocks. + "SIM105", # Use `contextlib.suppress(FileNotFoundError)` instead of `try`-`except`-`pass` + "UP035", # `typing.x` is deprecated, use `x` instead + "UP006", # `typing.x` is deprecated, use `x` instead +] + + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402"] +"**/{tests,docs}/*" = ["E402", "D", "F841", "ARG"] + + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" + + +[tool.ruff.lint.pydocstyle] +convention = "numpy" [edgetest.envs.core] python_version = "3.9" diff --git a/tests/test_database.py b/tests/test_database.py index 3366759..205e605 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -27,7 +27,6 @@ import psycopg2 import pytest import snowflake.connector - from locopy import Database from locopy.errors import CredentialsError, DBError @@ -57,7 +56,12 @@ def test_database_constructor(credentials, dbapi): @pytest.mark.parametrize("dbapi", DBAPIS) def test_database_constructor_kwargs(dbapi): d = Database( - dbapi=dbapi, host="host", port="port", database="database", user="user", password="password" + dbapi=dbapi, + host="host", + port="port", + database="database", + user="user", + password="password", ) assert d.connection["host"] == "host" assert d.connection["port"] == "port" @@ -132,7 +136,11 @@ def test_connect(credentials, dbapi): b = Database(dbapi=dbapi, **credentials) b.connect() mock_connect.assert_called_with( - host="host", user="user", port="port", password="password", database="database" + host="host", + user="user", + port="port", + password="password", + database="database", ) credentials["extra"] = 123 @@ -182,11 +190,12 @@ def test_disconnect_no_conn(credentials, dbapi): @pytest.mark.parametrize("dbapi", DBAPIS) def test_execute(credentials, dbapi): - with mock.patch(dbapi.__name__ + ".connect") as mock_connect: - with Database(dbapi=dbapi, **credentials) as test: - print(test) - test.execute("SELECT * FROM some_table") - assert test.cursor.execute.called is True + with ( + mock.patch(dbapi.__name__ + ".connect") as mock_connect, + Database(dbapi=dbapi, **credentials) as test, + ): + test.execute("SELECT * FROM some_table") + assert test.cursor.execute.called is True @pytest.mark.parametrize("dbapi", DBAPIS) @@ -201,18 +210,24 @@ def test_execute_no_connection_exception(credentials, dbapi): @pytest.mark.parametrize("dbapi", DBAPIS) def test_execute_sql_exception(credentials, dbapi): - with mock.patch(dbapi.__name__ + ".connect") as mock_connect: - with Database(dbapi=dbapi, **credentials) as test: - test.cursor.execute.side_effect = Exception("SQL Exception") - with pytest.raises(DBError): - test.execute("SELECT * FROM some_table") + with ( + mock.patch(dbapi.__name__ + ".connect") as mock_connect, + Database(dbapi=dbapi, **credentials) as test, + ): + test.cursor.execute.side_effect = Exception("SQL Exception") + with pytest.raises(DBError): + test.execute("SELECT * FROM some_table") @pytest.mark.parametrize("dbapi", DBAPIS) @mock.patch("pandas.DataFrame") def test_to_dataframe_all(mock_pandas, credentials, dbapi): with mock.patch(dbapi.__name__ + ".connect") as mock_connect: - mock_connect.return_value.cursor.return_value.fetchall.return_value = [(1, 2), (2, 3), (3,)] + mock_connect.return_value.cursor.return_value.fetchall.return_value = [ + (1, 2), + (2, 3), + (3,), + ] with Database(dbapi=dbapi, **credentials) as test: test.execute("SELECT 'hello world' AS fld") df = test.to_dataframe() @@ -256,11 +271,17 @@ def test_get_column_names(credentials, dbapi): with Database(dbapi=dbapi, **credentials) as test: assert test.column_names() == ["col1", "col2"] - mock_connect.return_value.cursor.return_value.description = [("COL1",), ("COL2",)] + mock_connect.return_value.cursor.return_value.description = [ + ("COL1",), + ("COL2",), + ] with Database(dbapi=dbapi, **credentials) as test: assert test.column_names() == ["col1", "col2"] - mock_connect.return_value.cursor.return_value.description = (("COL1",), ("COL2",)) + mock_connect.return_value.cursor.return_value.description = ( + ("COL1",), + ("COL2",), + ) with Database(dbapi=dbapi, **credentials) as test: assert test.column_names() == ["col1", "col2"] @@ -274,8 +295,16 @@ def cols(): test.column_names = cols test.cursor = [(1, 2), (2, 3), (3,)] - assert list(test.to_dict()) == [{"col1": 1, "col2": 2}, {"col1": 2, "col2": 3}, {"col1": 3}] + assert list(test.to_dict()) == [ + {"col1": 1, "col2": 2}, + {"col1": 2, "col2": 3}, + {"col1": 3}, + ] test.column_names = cols test.cursor = [("a", 2), (2, 3), ("b",)] - assert list(test.to_dict()) == [{"col1": "a", "col2": 2}, {"col1": 2, "col2": 3}, {"col1": "b"}] + assert list(test.to_dict()) == [ + {"col1": "a", "col2": 2}, + {"col1": 2, "col2": 3}, + {"col1": "b"}, + ] diff --git a/tests/test_integration.py b/tests/test_integration.py index ca8fb87..343a690 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -25,14 +25,13 @@ from pathlib import Path import boto3 +import locopy import numpy as np import pandas as pd import pg8000 import psycopg2 import pytest -import locopy - DBAPIS = [pg8000, psycopg2] INTEGRATION_CREDS = str(Path.home()) + "/.locopyrc" S3_BUCKET = "locopy-integration-testing" @@ -61,7 +60,6 @@ def s3_bucket(): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_redshift_execute_single_rows(dbapi): - expected = pd.DataFrame({"field_1": [1], "field_2": [2]}) with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as test: test.execute("SELECT 1 AS field_1, 2 AS field_2 ") @@ -73,7 +71,6 @@ def test_redshift_execute_single_rows(dbapi): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_redshift_execute_multiple_rows(dbapi): - expected = pd.DataFrame({"field_1": [1, 2], "field_2": [1, 2]}) with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as test: test.execute( @@ -90,7 +87,6 @@ def test_redshift_execute_multiple_rows(dbapi): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_s3_upload_download_file(s3_bucket, dbapi): - s3 = locopy.S3(**CREDS_DICT) s3.upload_to_s3(LOCAL_FILE, S3_BUCKET, "myfile.txt") @@ -104,7 +100,6 @@ def test_s3_upload_download_file(s3_bucket, dbapi): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_copy(s3_bucket, dbapi): - with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as redshift: redshift.execute( "CREATE TEMPORARY TABLE locopy_integration_testing (id INTEGER, variable VARCHAR(20)) DISTKEY(variable)" @@ -135,7 +130,6 @@ def test_copy(s3_bucket, dbapi): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_copy_split_ignore(s3_bucket, dbapi): - with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as redshift: redshift.execute( "CREATE TEMPORARY TABLE locopy_integration_testing (id INTEGER, variable VARCHAR(20)) DISTKEY(variable)" @@ -163,13 +157,12 @@ def test_copy_split_ignore(s3_bucket, dbapi): for i, result in enumerate(results): assert result[0] == expected[i][0] assert result[1] == expected[i][1] - os.remove(LOCAL_FILE_HEADER + ".{0}".format(i)) + os.remove(LOCAL_FILE_HEADER + f".{i}") @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_unload(s3_bucket, dbapi): - with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as redshift: redshift.execute( "CREATE TEMPORARY TABLE locopy_integration_testing AS SELECT ('2017-12-31'::date + row_number() over (order by 1))::date from SVV_TABLES LIMIT 5" @@ -231,7 +224,6 @@ def test_unload_raw_unload_path(s3_bucket, dbapi): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_insert_dataframe_to_table(s3_bucket, dbapi): - with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as redshift: redshift.insert_dataframe_to_table(TEST_DF, "locopy_df_test", create=True) redshift.execute("SELECT a, b, c FROM locopy_df_test ORDER BY a ASC") diff --git a/tests/test_integration_sf.py b/tests/test_integration_sf.py index 60f9bcb..f5c7ab4 100644 --- a/tests/test_integration_sf.py +++ b/tests/test_integration_sf.py @@ -26,13 +26,12 @@ from pathlib import Path import boto3 +import locopy import numpy as np import pandas as pd import pytest import snowflake.connector -import locopy - DBAPIS = [snowflake.connector] INTEGRATION_CREDS = str(Path.home()) + os.sep + ".locopy-sfrc" S3_BUCKET = "locopy-integration-testing" @@ -60,7 +59,6 @@ def s3_bucket(): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_snowflake_execute_single_rows(dbapi): - expected = pd.DataFrame({"field_1": [1], "field_2": [2]}) with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test: test.execute("SELECT 1 AS field_1, 2 AS field_2 ") @@ -73,11 +71,12 @@ def test_snowflake_execute_single_rows(dbapi): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_snowflake_execute_multiple_rows(dbapi): - expected = pd.DataFrame({"field_1": [1, 2], "field_2": [1, 2]}) with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test: test.execute( - "SELECT 1 AS field_1, 1 AS field_2 " "UNION " "SELECT 2 AS field_1, 2 AS field_2" + "SELECT 1 AS field_1, 1 AS field_2 " + "UNION " + "SELECT 2 AS field_1, 2 AS field_2" ) df = test.to_dataframe() df.columns = [c.lower() for c in df.columns] @@ -89,7 +88,6 @@ def test_snowflake_execute_multiple_rows(dbapi): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_upload_download_internal(dbapi): - with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test: # delete if exists test.execute("REMOVE @~/staged/mock_file_dl.txt") @@ -114,7 +112,6 @@ def test_upload_download_internal(dbapi): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_copy(dbapi): - with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test: test.upload_to_internal(LOCAL_FILE, "@~/staged/") test.execute("USE SCHEMA {}".format(CREDS_DICT["schema"])) @@ -144,7 +141,6 @@ def test_copy(dbapi): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_copy_json(dbapi): - with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test: test.upload_to_internal(LOCAL_FILE_JSON, "@~/staged/") test.execute("USE SCHEMA {}".format(CREDS_DICT["schema"])) @@ -176,7 +172,6 @@ def test_copy_json(dbapi): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_to_dataframe(dbapi): - with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test: test.upload_to_internal(LOCAL_FILE_JSON, "@~/staged/") test.execute("USE SCHEMA {}".format(CREDS_DICT["schema"])) @@ -197,11 +192,17 @@ def test_to_dataframe(dbapi): result = test.to_dataframe() result.columns = [c.lower() for c in result.columns] expected = pd.DataFrame( - [('"Belmont"', '"92567"'), ('"Lexington"', '"75836"'), ('"Winchester"', '"89921"'),], + [ + ('"Belmont"', '"92567"'), + ('"Lexington"', '"75836"'), + ('"Winchester"', '"89921"'), + ], columns=["variable:location:city", "variable:price"], ) - assert (result["variable:location:city"] == expected["variable:location:city"]).all() + assert ( + result["variable:location:city"] == expected["variable:location:city"] + ).all() assert (result["variable:price"] == expected["variable:price"]).all() # with size of 2 @@ -211,11 +212,16 @@ def test_to_dataframe(dbapi): result = test.to_dataframe(size=2) result.columns = [c.lower() for c in result.columns] expected = pd.DataFrame( - [('"Belmont"', '"92567"'), ('"Lexington"', '"75836"'),], + [ + ('"Belmont"', '"92567"'), + ('"Lexington"', '"75836"'), + ], columns=["variable:location:city", "variable:price"], ) - assert (result["variable:location:city"] == expected["variable:location:city"]).all() + assert ( + result["variable:location:city"] == expected["variable:location:city"] + ).all() assert (result["variable:price"] == expected["variable:price"]).all() # with non-select query @@ -227,7 +233,6 @@ def test_to_dataframe(dbapi): @pytest.mark.integration @pytest.mark.parametrize("dbapi", DBAPIS) def test_insert_dataframe_to_table(dbapi): - with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test: test.insert_dataframe_to_table(TEST_DF, "test", create=True) test.execute("SELECT a, b, c FROM test ORDER BY a ASC") @@ -250,7 +255,15 @@ def test_insert_dataframe_to_table(dbapi): results = test.cursor.fetchall() test.execute("drop table if exists test_2") - expected = [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e"), (6, "f"), (7, "g")] + expected = [ + (1, "a"), + (2, "b"), + (3, "c"), + (4, "d"), + (5, "e"), + (6, "f"), + (7, "g"), + ] assert len(expected) == len(results) for i, result in enumerate(results): @@ -280,5 +293,5 @@ def test_insert_dataframe_to_table(dbapi): assert len(expected) == len(results) for i, result in enumerate(results): - for j, item in enumerate(result): + for j, _item in enumerate(result): assert result[j] == expected[i][j] diff --git a/tests/test_redshift.py b/tests/test_redshift.py index d077ddf..55a5f2f 100644 --- a/tests/test_redshift.py +++ b/tests/test_redshift.py @@ -24,13 +24,12 @@ from collections import OrderedDict from unittest import mock +import locopy import pg8000 import psycopg2 import pytest - -import locopy from locopy import Redshift -from locopy.errors import CredentialsError, DBError +from locopy.errors import DBError PROFILE = "test" GOOD_CONFIG_YAML = """ @@ -50,7 +49,9 @@ def test_add_default_copy_options(): "COMPUPDATE ON", "TRUNCATECOLUMNS", ] - assert locopy.redshift.add_default_copy_options(["DATEFORMAT 'other'", "NULL AS 'blah'"]) == [ + assert locopy.redshift.add_default_copy_options( + ["DATEFORMAT 'other'", "NULL AS 'blah'"] + ) == [ "DATEFORMAT 'other'", "NULL AS 'blah'", "COMPUPDATE ON", @@ -59,9 +60,9 @@ def test_add_default_copy_options(): def test_combine_copy_options(): - assert locopy.redshift.combine_copy_options(locopy.redshift.add_default_copy_options()) == ( - "DATEFORMAT 'auto' COMPUPDATE " "ON TRUNCATECOLUMNS" - ) + assert locopy.redshift.combine_copy_options( + locopy.redshift.add_default_copy_options() + ) == ("DATEFORMAT 'auto' COMPUPDATE " "ON TRUNCATECOLUMNS") @pytest.mark.parametrize("dbapi", DBAPIS) @@ -70,7 +71,7 @@ def test_constructor(mock_session, credentials, dbapi): r = Redshift(profile=PROFILE, dbapi=dbapi, **credentials) mock_session.assert_called_with(profile_name=PROFILE) assert r.profile == PROFILE - assert r.kms_key == None + assert r.kms_key is None assert r.connection["host"] == "host" assert r.connection["port"] == "port" assert r.connection["database"] == "database" @@ -91,7 +92,7 @@ def test_constructor_yaml(mock_session, dbapi): r = Redshift(profile=PROFILE, dbapi=dbapi, config_yaml="some_config.yml") mock_session.assert_called_with(profile_name=PROFILE) assert r.profile == PROFILE - assert r.kms_key == None + assert r.kms_key is None assert r.connection["host"] == "host" assert r.connection["port"] == "port" assert r.connection["database"] == "database" @@ -140,21 +141,18 @@ def test_copy_parquet(mock_execute, mock_session, credentials, dbapi): r = Redshift(profile=PROFILE, dbapi=dbapi, **credentials) r.connect() r.copy("table", s3path="path", delim=None, copy_options=["PARQUET"]) - test_sql = ( - "COPY {0} FROM '{1}' " - "CREDENTIALS '{2}' " - "{3};".format("table", "path", r._credentials_string(), "PARQUET") + test_sql = "COPY {} FROM '{}' " "CREDENTIALS '{}' " "{};".format( + "table", "path", r._credentials_string(), "PARQUET" ) assert mock_execute.called_with(test_sql, commit=True) mock_execute.reset_mock() mock_session.reset_mock() r.copy("table", s3path="path", delim=None) - test_sql = ( - "COPY {0} FROM '{1}' " - "CREDENTIALS '{2}' " - "{3};".format( - "table", "path", r._credentials_string(), locopy.redshift.add_default_copy_options() - ) + test_sql = "COPY {} FROM '{}' " "CREDENTIALS '{}' " "{};".format( + "table", + "path", + r._credentials_string(), + locopy.redshift.add_default_copy_options(), ) assert mock_execute.called_with(test_sql, commit=True) @@ -261,7 +259,10 @@ def reset_mocks(): # mock_remove.assert_called_with("/path/local_file.2") mock_s3_upload.assert_has_calls(expected_calls_no_folder_gzip) mock_rs_copy.assert_called_with( - "table_name", "s3://s3_bucket/local_file", "|", copy_options=["SOME OPTION", "GZIP"] + "table_name", + "s3://s3_bucket/local_file", + "|", + copy_options=["SOME OPTION", "GZIP"], ) assert mock_s3_delete.called_with("s3_bucket", "local_file.0.gz") assert mock_s3_delete.called_with("s3_bucket", "local_file.1.gz") @@ -335,7 +336,10 @@ def reset_mocks(): "/path/local_file.txt", "s3_bucket", "test/local_file.txt" ) mock_rs_copy.assert_called_with( - "table_name", "s3://s3_bucket/test/local_file", "|", copy_options=["SOME OPTION"] + "table_name", + "s3://s3_bucket/test/local_file", + "|", + copy_options=["SOME OPTION"], ) assert not mock_s3_delete.called @@ -366,7 +370,10 @@ def reset_mocks(): # assert not mock_remove.called mock_s3_upload.assert_has_calls(expected_calls_folder) mock_rs_copy.assert_called_with( - "table_name", "s3://s3_bucket/test/local_file", "|", copy_options=["SOME OPTION"] + "table_name", + "s3://s3_bucket/test/local_file", + "|", + copy_options=["SOME OPTION"], ) assert mock_s3_delete.called_with("s3_bucket", "test/local_file.0") assert mock_s3_delete.called_with("s3_bucket", "test/local_file.1") @@ -467,7 +474,7 @@ def reset_mocks(): mock_split_file.return_value = ["/path/local_file.txt"] mock_compress_file_list.return_value = ["/path/local_file.txt.gz"] - #### neither ignore or split only + # neither ignore or split only r.load_and_copy("/path/local_file.txt", "s3_bucket", "table_name", delim="|") # assert @@ -482,7 +489,7 @@ def reset_mocks(): ) assert not mock_s3_delete.called, "Only delete when explicit" - #### ignore only + # ignore only reset_mocks() r.load_and_copy( "/path/local_file.txt", @@ -507,7 +514,7 @@ def reset_mocks(): ) assert not mock_s3_delete.called, "Only delete when explicit" - #### split only + # split only reset_mocks() mock_split_file.return_value = [ "/path/local_file.0", @@ -520,7 +527,12 @@ def reset_mocks(): "/path/local_file.2.gz", ] r.load_and_copy( - "/path/local_file", "s3_bucket", "table_name", delim="|", splits=3, delete_s3_after=True + "/path/local_file", + "s3_bucket", + "table_name", + delim="|", + splits=3, + delete_s3_after=True, ) # assert @@ -538,7 +550,7 @@ def reset_mocks(): assert mock_s3_delete.called_with("s3_bucket", "local_file.1.gz") assert mock_s3_delete.called_with("s3_bucket", "local_file.2.gz") - #### split and ignore + # split and ignore reset_mocks() mock_split_file.return_value = [ "/path/local_file.0", @@ -580,7 +592,6 @@ def reset_mocks(): @pytest.mark.parametrize("dbapi", DBAPIS) @mock.patch("locopy.s3.Session") def test_redshiftcopy(mock_session, credentials, dbapi): - with mock.patch(dbapi.__name__ + ".connect") as mock_connect: r = locopy.Redshift(dbapi=dbapi, **credentials) r.connect() @@ -589,13 +600,9 @@ def test_redshiftcopy(mock_session, credentials, dbapi): ( mock_connect.return_value.cursor.return_value.execute.assert_called_with( "COPY table FROM 's3bucket' CREDENTIALS " - "'aws_access_key_id={0};aws_secret_access_key={1};token={2}' " + f"'aws_access_key_id={r.session.get_credentials().access_key};aws_secret_access_key={r.session.get_credentials().secret_key};token={r.session.get_credentials().token}' " "DELIMITER '|' DATEFORMAT 'auto' COMPUPDATE ON " - "TRUNCATECOLUMNS;".format( - r.session.get_credentials().access_key, - r.session.get_credentials().secret_key, - r.session.get_credentials().token, - ), + "TRUNCATECOLUMNS;", (), ) ) @@ -606,13 +613,9 @@ def test_redshiftcopy(mock_session, credentials, dbapi): ( mock_connect.return_value.cursor.return_value.execute.assert_called_with( "COPY table FROM 's3bucket' CREDENTIALS " - "'aws_access_key_id={0};aws_secret_access_key={1};token={2}' " + f"'aws_access_key_id={r.session.get_credentials().access_key};aws_secret_access_key={r.session.get_credentials().secret_key};token={r.session.get_credentials().token}' " "DELIMITER '\t' DATEFORMAT 'auto' COMPUPDATE ON " - "TRUNCATECOLUMNS;".format( - r.session.get_credentials().access_key, - r.session.get_credentials().secret_key, - r.session.get_credentials().token, - ), + "TRUNCATECOLUMNS;", (), ) ) @@ -622,13 +625,9 @@ def test_redshiftcopy(mock_session, credentials, dbapi): ( mock_connect.return_value.cursor.return_value.execute.assert_called_with( "COPY table FROM 's3bucket' CREDENTIALS " - "'aws_access_key_id={0};aws_secret_access_key={1};token={2}' " + f"'aws_access_key_id={r.session.get_credentials().access_key};aws_secret_access_key={r.session.get_credentials().secret_key};token={r.session.get_credentials().token}' " "DATEFORMAT 'auto' COMPUPDATE ON " - "TRUNCATECOLUMNS;".format( - r.session.get_credentials().access_key, - r.session.get_credentials().secret_key, - r.session.get_credentials().token, - ), + "TRUNCATECOLUMNS;", (), ) ) @@ -638,7 +637,6 @@ def test_redshiftcopy(mock_session, credentials, dbapi): @mock.patch("locopy.s3.Session") @mock.patch("locopy.database.Database._is_connected") def test_redshiftcopy_exception(mock_connected, mock_session, credentials, dbapi): - with mock.patch(dbapi.__name__ + ".connect") as mock_connect: r = locopy.Redshift(dbapi=dbapi, **credentials) mock_connected.return_value = False @@ -691,13 +689,13 @@ def reset_mocks(): r = locopy.Redshift(dbapi=dbapi, **credentials) ## - ## Test 1: check that basic export pipeline functions are called + # Test 1: check that basic export pipeline functions are called mock_unload_generated_files.return_value = ["dummy_file"] mock_download_list_from_s3.return_value = ["s3.file"] mock_get_col_names.return_value = ["dummy_col_name"] mock_generate_unload_path.return_value = "dummy_s3_path" - ## ensure nothing is returned when read=False + # ensure nothing is returned when read=False r.unload_and_copy( query="query", s3_bucket="s3_bucket", @@ -710,7 +708,9 @@ def reset_mocks(): ) assert mock_unload_generated_files.called - assert not mock_write.called, "write_file should only be called " "if export_path != False" + assert not mock_write.called, ( + "write_file should only be called " "if export_path != False" + ) mock_generate_unload_path.assert_called_with("s3_bucket", None) mock_get_col_names.assert_called_with("query") mock_unload.assert_called_with( @@ -719,7 +719,7 @@ def reset_mocks(): assert not mock_delete_list_from_s3.called ## - ## Test 2: different delimiter + # Test 2: different delimiter reset_mocks() mock_unload_generated_files.return_value = ["dummy_file"] mock_download_list_from_s3.return_value = ["s3.file"] @@ -736,14 +736,16 @@ def reset_mocks(): parallel_off=True, ) - ## check that unload options are modified based on supplied args + # check that unload options are modified based on supplied args mock_unload.assert_called_with( - query="query", s3path="dummy_s3_path", unload_options=["DELIMITER '|'", "PARALLEL OFF"] + query="query", + s3path="dummy_s3_path", + unload_options=["DELIMITER '|'", "PARALLEL OFF"], ) assert not mock_delete_list_from_s3.called ## - ## Test 2.5: delimiter is none + # Test 2.5: delimiter is none reset_mocks() mock_unload_generated_files.return_value = ["dummy_file"] mock_download_list_from_s3.return_value = ["s3.file"] @@ -760,32 +762,32 @@ def reset_mocks(): parallel_off=True, ) - ## check that unload options are modified based on supplied args + # check that unload options are modified based on supplied args mock_unload.assert_called_with( query="query", s3path="dummy_s3_path", unload_options=["PARALLEL OFF"] ) assert not mock_delete_list_from_s3.called ## - ## Test 3: ensure exception is raised when no column names are retrieved + # Test 3: ensure exception is raised when no column names are retrieved reset_mocks() mock_unload_generated_files.return_value = ["dummy_file"] mock_generate_unload_path.return_value = "dummy_s3_path" mock_get_col_names.return_value = None - with pytest.raises(Exception): + with pytest.raises(DBError): r.unload_and_copy("query", "s3_bucket", None) ## - ## Test 4: ensure exception is raised when no files are returned + # Test 4: ensure exception is raised when no files are returned reset_mocks() mock_generate_unload_path.return_value = "dummy_s3_path" mock_get_col_names.return_value = ["dummy_col_name"] mock_unload_generated_files.return_value = None - with pytest.raises(Exception): + with pytest.raises(DBError): r.unload_and_copy("query", "s3_bucket", None) ## - ## Test 5: ensure file writing is initiated when export_path is supplied + # Test 5: ensure file writing is initiated when export_path is supplied reset_mocks() mock_get_col_names.return_value = ["dummy_col_name"] mock_download_list_from_s3.return_value = ["s3.file"] @@ -801,18 +803,20 @@ def reset_mocks(): delete_s3_after=True, parallel_off=False, ) - mock_concat.assert_called_with(mock_download_list_from_s3.return_value, "my_output.csv") + mock_concat.assert_called_with( + mock_download_list_from_s3.return_value, "my_output.csv" + ) assert mock_write.called assert mock_delete_list_from_s3.called_with("s3_bucket", "my_output.csv") ## - ## Test 6: raw_unload_path check + # Test 6: raw_unload_path check reset_mocks() mock_get_col_names.return_value = ["dummy_col_name"] mock_download_list_from_s3.return_value = ["s3.file"] mock_generate_unload_path.return_value = "dummy_s3_path" mock_unload_generated_files.return_value = ["/dummy_file"] - ## ensure nothing is returned when read=False + # ensure nothing is returned when read=False r.unload_and_copy( query="query", s3_bucket="s3_bucket", @@ -833,7 +837,7 @@ def test_unload_generated_files(mock_session, credentials, dbapi): r = locopy.Redshift(dbapi=dbapi, **credentials) r.connect() r._unload_generated_files() - assert r._unload_generated_files() == None + assert r._unload_generated_files() is None mock_connect.return_value.cursor.return_value.fetchall.return_value = [ ["File1 "], @@ -844,10 +848,10 @@ def test_unload_generated_files(mock_session, credentials, dbapi): r._unload_generated_files() assert r._unload_generated_files() == ["File1", "File2"] - mock_connect.return_value.cursor.return_value.execute.side_effect = Exception() + mock_connect.return_value.cursor.return_value.execute.side_effect = DBError() r = locopy.Redshift(dbapi=dbapi, **credentials) r.connect() - with pytest.raises(Exception): + with pytest.raises(DBError): r._unload_generated_files() @@ -857,19 +861,24 @@ def test_get_column_names(mock_session, credentials, dbapi): with mock.patch(dbapi.__name__ + ".connect") as mock_connect: r = locopy.Redshift(dbapi=dbapi, **credentials) r.connect() - assert r._get_column_names("query") == None + assert r._get_column_names("query") is None sql = "SELECT * FROM (query) WHERE 1 = 0" - assert mock_connect.return_value.cursor.return_value.execute.called_with(sql, ()) + assert mock_connect.return_value.cursor.return_value.execute.called_with( + sql, () + ) - mock_connect.return_value.cursor.return_value.description = [["COL1 "], ["COL2 "]] + mock_connect.return_value.cursor.return_value.description = [ + ["COL1 "], + ["COL2 "], + ] r = locopy.Redshift(dbapi=dbapi, **credentials) r.connect() assert r._get_column_names("query") == ["COL1", "COL2"] - mock_connect.return_value.cursor.return_value.execute.side_effect = Exception() + mock_connect.return_value.cursor.return_value.execute.side_effect = DBError() r = locopy.Redshift(dbapi=dbapi, **credentials) r.connect() - with pytest.raises(Exception): + with pytest.raises(DBError): r._get_column_names("query") @@ -888,13 +897,13 @@ def testunload(mock_session, credentials, dbapi): def testunload_no_connection(mock_session, credentials, dbapi): with mock.patch(dbapi.__name__ + ".connect") as mock_connect: r = locopy.Redshift(dbapi=dbapi, **credentials) - with pytest.raises(Exception): + with pytest.raises(DBError): r.unload("query", "path") - mock_connect.return_value.cursor.return_value.execute.side_effect = Exception() + mock_connect.return_value.cursor.return_value.execute.side_effect = DBError() r = locopy.Redshift(dbapi=dbapi, **credentials) r.connect() - with pytest.raises(Exception): + with pytest.raises(DBError): r.unload("query", "path") @@ -932,7 +941,9 @@ def testinsert_dataframe_to_table(mock_session, credentials, dbapi): test_df, "database.schema.test", create=True, - metadata=OrderedDict([("col1", "int"), ("col2", "varchar"), ("col3", "date")]), + metadata=OrderedDict( + [("col1", "int"), ("col2", "varchar"), ("col3", "date")] + ), ) mock_connect.return_value.cursor.return_value.execute.assert_any_call( @@ -943,11 +954,15 @@ def testinsert_dataframe_to_table(mock_session, credentials, dbapi): (), ) - r.insert_dataframe_to_table(test_df, "database.schema.test", create=False, batch_size=1) + r.insert_dataframe_to_table( + test_df, "database.schema.test", create=False, batch_size=1 + ) mock_connect.return_value.cursor.return_value.execute.assert_any_call( - "INSERT INTO database.schema.test (a,b,c) VALUES ('1', 'x', '2011-01-01')", () + "INSERT INTO database.schema.test (a,b,c) VALUES ('1', 'x', '2011-01-01')", + (), ) mock_connect.return_value.cursor.return_value.execute.assert_any_call( - "INSERT INTO database.schema.test (a,b,c) VALUES ('2', 'y', '2001-04-02')", () + "INSERT INTO database.schema.test (a,b,c) VALUES ('2', 'y', '2001-04-02')", + (), ) diff --git a/tests/test_s3.py b/tests/test_s3.py index 1d23b81..53f7f91 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -25,12 +25,11 @@ from unittest import mock import hypothesis.strategies as st +import locopy import pg8000 import psycopg2 import pytest from hypothesis import given - -import locopy from locopy.errors import ( S3CredentialsError, S3DeletionError, @@ -63,7 +62,7 @@ def test_mock_s3_session_profile_without_kms(profile, mock_session, dbapi): s = locopy.S3(profile=profile) mock_session.assert_called_with(profile_name=profile) - assert s.kms_key == None + assert s.kms_key is None @pytest.mark.parametrize("dbapi", DBAPIS) @@ -80,7 +79,7 @@ def test_mock_s3_session_profile_with_kms(input_kms_key, profile, mock_session, def test_mock_s3_session_profile_without_any(mock_session, dbapi): s = locopy.S3() mock_session.assert_called_with(profile_name=None) - assert s.kms_key == None + assert s.kms_key is None @pytest.mark.parametrize("dbapi", DBAPIS) @@ -122,8 +121,8 @@ def test_get_credentials(mock_cred, aws_creds): expected = "aws_access_key_id=access;" "aws_secret_access_key=secret" assert cred_string == expected - mock_cred.side_effect = Exception("Exception") - with pytest.raises(Exception): + mock_cred.side_effect = S3CredentialsError("Exception") + with pytest.raises(S3CredentialsError): locopy.S3() @@ -139,7 +138,10 @@ def test_generate_s3_path(mock_session): def test_generate_unload_path(mock_session): s = locopy.S3() assert s._generate_unload_path("TEST", "FOLDER/") == "s3://TEST/FOLDER/" - assert s._generate_unload_path("TEST SPACE", "FOLDER SPACE/") == "s3://TEST SPACE/FOLDER SPACE/" + assert ( + s._generate_unload_path("TEST SPACE", "FOLDER SPACE/") + == "s3://TEST SPACE/FOLDER SPACE/" + ) assert s._generate_unload_path("TEST", "PREFIX") == "s3://TEST/PREFIX" assert s._generate_unload_path("TEST", None) == "s3://TEST" @@ -267,7 +269,10 @@ def test_download_from_s3(mock_session, mock_config): s = locopy.S3() s.download_from_s3(S3_DEFAULT_BUCKET, LOCAL_TEST_FILE, LOCAL_TEST_FILE) s.s3.download_file.assert_called_with( - S3_DEFAULT_BUCKET, LOCAL_TEST_FILE, os.path.basename(LOCAL_TEST_FILE), Config=mock_config() + S3_DEFAULT_BUCKET, + LOCAL_TEST_FILE, + os.path.basename(LOCAL_TEST_FILE), + Config=mock_config(), ) mock_config.side_effect = Exception() @@ -316,7 +321,9 @@ def test_delete_list_from_s3_multiple_with_folder(mock_session, mock_delete): mock.call("test_bucket", "test_folder/test.2"), ] s = locopy.S3() - s.delete_list_from_s3(["test_bucket/test_folder/test.1", "test_bucket/test_folder/test.2"]) + s.delete_list_from_s3( + ["test_bucket/test_folder/test.1", "test_bucket/test_folder/test.2"] + ) mock_delete.assert_has_calls(calls) @@ -331,7 +338,9 @@ def test_delete_list_from_s3_multiple_without_folder(mock_session, mock_delete): @mock.patch("locopy.s3.S3.delete_from_s3") @mock.patch("locopy.s3.Session") -def test_delete_list_from_s3_single_with_folder_and_special_chars(mock_session, mock_delete): +def test_delete_list_from_s3_single_with_folder_and_special_chars( + mock_session, mock_delete +): calls = [mock.call("test_bucket", r"test_folder/#$#@$@#$dffksdojfsdf\\\\\/test.1")] s = locopy.S3() s.delete_list_from_s3([r"test_bucket/test_folder/#$#@$@#$dffksdojfsdf\\\\\/test.1"]) @@ -344,22 +353,33 @@ def test_delete_list_from_s3_exception(mock_session, mock_delete): s = locopy.S3() mock_delete.side_effect = S3UploadError("Upload Exception") with pytest.raises(S3UploadError): - s.delete_list_from_s3(["test_bucket/test_folder/test.1", "test_bucket/test_folder/test.2"]) + s.delete_list_from_s3( + ["test_bucket/test_folder/test.1", "test_bucket/test_folder/test.2"] + ) @mock.patch("locopy.s3.Session") def test_parse_s3_url(mock_session): s = locopy.S3() - assert s.parse_s3_url("s3://bucket/folder/file.txt") == ("bucket", "folder/file.txt") + assert s.parse_s3_url("s3://bucket/folder/file.txt") == ( + "bucket", + "folder/file.txt", + ) assert s.parse_s3_url("s3://bucket/folder/") == ("bucket", "folder/") assert s.parse_s3_url("s3://bucket") == ("bucket", "") - assert s.parse_s3_url(r"s3://bucket/!@#$%\\\/file.txt") == ("bucket", r"!@#$%\\\/file.txt") + assert s.parse_s3_url(r"s3://bucket/!@#$%\\\/file.txt") == ( + "bucket", + r"!@#$%\\\/file.txt", + ) assert s.parse_s3_url("s3://") == ("", "") assert s.parse_s3_url("bucket/folder/file.txt") == ("bucket", "folder/file.txt") assert s.parse_s3_url("bucket/folder/") == ("bucket", "folder/") assert s.parse_s3_url("bucket") == ("bucket", "") - assert s.parse_s3_url(r"bucket/!@#$%\\\/file.txt") == ("bucket", r"!@#$%\\\/file.txt") + assert s.parse_s3_url(r"bucket/!@#$%\\\/file.txt") == ( + "bucket", + r"!@#$%\\\/file.txt", + ) assert s.parse_s3_url("") == ("", "") @@ -394,7 +414,10 @@ def test_download_list_from_s3_multiple(mock_session, mock_download): ] s = locopy.S3() res = s.download_list_from_s3(["s3://bucket/test.1", "s3://bucket/test.2"]) - assert res == [os.path.join(os.getcwd(), "test.1"), os.path.join(os.getcwd(), "test.2")] + assert res == [ + os.path.join(os.getcwd(), "test.1"), + os.path.join(os.getcwd(), "test.2"), + ] mock_download.assert_has_calls(calls) @@ -407,8 +430,13 @@ def test_download_list_from_s3_multiple_with_localpath(mock_session, mock_downlo mock.call("bucket", "test.2", os.path.join(tmp_path.name, "test.2")), ] s = locopy.S3() - res = s.download_list_from_s3(["s3://bucket/test.1", "s3://bucket/test.2"], tmp_path.name) - assert res == [os.path.join(tmp_path.name, "test.1"), os.path.join(tmp_path.name, "test.2")] + res = s.download_list_from_s3( + ["s3://bucket/test.1", "s3://bucket/test.2"], tmp_path.name + ) + assert res == [ + os.path.join(tmp_path.name, "test.1"), + os.path.join(tmp_path.name, "test.2"), + ] mock_download.assert_has_calls(calls) tmp_path.cleanup() diff --git a/tests/test_snowflake.py b/tests/test_snowflake.py index a6b4aa2..d24e76b 100644 --- a/tests/test_snowflake.py +++ b/tests/test_snowflake.py @@ -27,13 +27,12 @@ from unittest import mock import hypothesis.strategies as s +import locopy import pytest import snowflake.connector from hypothesis import HealthCheck, given, settings - -import locopy from locopy import Snowflake -from locopy.errors import CredentialsError, DBError +from locopy.errors import DBError PROFILE = "test" KMS = "kms_test" @@ -159,84 +158,92 @@ def test_with_connect(mock_session, sf_credentials): sf.conn.cursor.return_value.execute.assert_any_call("USE SCHEMA schema", ()) mock_connect.side_effect = Exception("Connect Exception") - with pytest.raises(DBError): - with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf: - sf.cursor + with ( + pytest.raises(DBError), + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.cursor # noqa: B018 @mock.patch("locopy.s3.Session") def test_upload_to_internal(mock_session, sf_credentials): - with mock.patch("snowflake.connector.connect") as mock_connect: - with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf: - sf.upload_to_internal("/some/file", "@~/internal") - sf.conn.cursor.return_value.execute.assert_called_with( - "PUT 'file:///some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=True", - (), - ) + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.upload_to_internal("/some/file", "@~/internal") + sf.conn.cursor.return_value.execute.assert_called_with( + "PUT 'file:///some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=True", + (), + ) - sf.upload_to_internal( - "/some/file", "@~/internal", parallel=99, auto_compress=False - ) - sf.conn.cursor.return_value.execute.assert_called_with( - "PUT 'file:///some/file' @~/internal PARALLEL=99 AUTO_COMPRESS=False OVERWRITE=True", - (), - ) + sf.upload_to_internal( + "/some/file", "@~/internal", parallel=99, auto_compress=False + ) + sf.conn.cursor.return_value.execute.assert_called_with( + "PUT 'file:///some/file' @~/internal PARALLEL=99 AUTO_COMPRESS=False OVERWRITE=True", + (), + ) - sf.upload_to_internal("/some/file", "@~/internal", overwrite=False) - sf.conn.cursor.return_value.execute.assert_called_with( - "PUT 'file:///some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=False", - (), - ) + sf.upload_to_internal("/some/file", "@~/internal", overwrite=False) + sf.conn.cursor.return_value.execute.assert_called_with( + "PUT 'file:///some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=False", + (), + ) - # exception - sf.conn.cursor.return_value.execute.side_effect = Exception("PUT Exception") - with pytest.raises(DBError): - sf.upload_to_internal("/some/file", "@~/internal") + # exception + sf.conn.cursor.return_value.execute.side_effect = Exception("PUT Exception") + with pytest.raises(DBError): + sf.upload_to_internal("/some/file", "@~/internal") @mock.patch("locopy.snowflake.PurePath", new=PureWindowsPath) @mock.patch("locopy.s3.Session") def test_upload_to_internal_windows(mock_session, sf_credentials): - with mock.patch("snowflake.connector.connect") as mock_connect: - with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf: - - sf.upload_to_internal(r"C:\some\file", "@~/internal") - sf.conn.cursor.return_value.execute.assert_called_with( - "PUT 'file://C:/some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=True", - (), - ) + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.upload_to_internal(r"C:\some\file", "@~/internal") + sf.conn.cursor.return_value.execute.assert_called_with( + "PUT 'file://C:/some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=True", + (), + ) @mock.patch("locopy.s3.Session") def test_download_from_internal(mock_session, sf_credentials): - with mock.patch("snowflake.connector.connect") as mock_connect: - with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf: - sf.download_from_internal("@~/internal", "/some/file") - sf.conn.cursor.return_value.execute.assert_called_with( - "GET @~/internal 'file:///some/file' PARALLEL=10", () - ) + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.download_from_internal("@~/internal", "/some/file") + sf.conn.cursor.return_value.execute.assert_called_with( + "GET @~/internal 'file:///some/file' PARALLEL=10", () + ) - sf.download_from_internal("@~/internal", "/some/file", parallel=99) - sf.conn.cursor.return_value.execute.assert_called_with( - "GET @~/internal 'file:///some/file' PARALLEL=99", () - ) + sf.download_from_internal("@~/internal", "/some/file", parallel=99) + sf.conn.cursor.return_value.execute.assert_called_with( + "GET @~/internal 'file:///some/file' PARALLEL=99", () + ) - # exception - sf.conn.cursor.return_value.execute.side_effect = Exception("GET Exception") - with pytest.raises(DBError): - sf.download_from_internal("@~/internal", "/some/file") + # exception + sf.conn.cursor.return_value.execute.side_effect = Exception("GET Exception") + with pytest.raises(DBError): + sf.download_from_internal("@~/internal", "/some/file") @mock.patch("locopy.snowflake.PurePath", new=PureWindowsPath) @mock.patch("locopy.s3.Session") def test_download_from_internal_windows(mock_session, sf_credentials): - with mock.patch("snowflake.connector.connect") as mock_connect: - with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf: - - sf.download_from_internal("@~/internal", r"C:\some\file") - sf.conn.cursor.return_value.execute.assert_called_with( - "GET @~/internal 'file://C:/some/file' PARALLEL=10", () - ) + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.download_from_internal("@~/internal", r"C:\some\file") + sf.conn.cursor.return_value.execute.assert_called_with( + "GET @~/internal 'file://C:/some/file' PARALLEL=10", () + ) @pytest.mark.parametrize( @@ -275,42 +282,40 @@ def test_download_from_internal_windows(mock_session, sf_credentials): def test_copy( mock_session, file_type, format_options, copy_options, expected, sf_credentials ): - with mock.patch("snowflake.connector.connect") as mock_connect: - with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf: - - sf.copy( - "table_name", - "@~/stage", - file_type=file_type, - format_options=format_options, - copy_options=copy_options, - ) - sf.conn.cursor.return_value.execute.assert_called_with( - "COPY INTO table_name FROM '@~/stage' FILE_FORMAT = {0}".format( - expected - ), - (), - ) + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.copy( + "table_name", + "@~/stage", + file_type=file_type, + format_options=format_options, + copy_options=copy_options, + ) + sf.conn.cursor.return_value.execute.assert_called_with( + f"COPY INTO table_name FROM '@~/stage' FILE_FORMAT = {expected}", + (), + ) @mock.patch("locopy.s3.Session") def test_copy_exception(mock_session, sf_credentials): - with mock.patch("snowflake.connector.connect") as mock_connect: - with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf: - - with pytest.raises(ValueError): - sf.copy("table_name", "@~/stage", file_type="unknown") - - # exception - sf.conn.cursor.return_value.execute.side_effect = Exception( - "COPY Exception" - ) - with pytest.raises(DBError): - sf.copy("table_name", "@~/stage") + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + with pytest.raises(ValueError): + sf.copy("table_name", "@~/stage", file_type="unknown") + + # exception + sf.conn.cursor.return_value.execute.side_effect = Exception("COPY Exception") + with pytest.raises(DBError): + sf.copy("table_name", "@~/stage") - sf.conn = None - with pytest.raises(DBError): - sf.copy("table_name", "@~/stage") + sf.conn = None + with pytest.raises(DBError): + sf.copy("table_name", "@~/stage") @pytest.mark.parametrize( @@ -359,41 +364,41 @@ def test_unload( expected, sf_credentials, ): - with mock.patch("snowflake.connector.connect") as mock_connect: - with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf: - - sf.unload( - "@~/stage", - "table_name", - file_type=file_type, - format_options=format_options, - header=header, - copy_options=copy_options, - ) - sf.conn.cursor.return_value.execute.assert_called_with( - "COPY INTO @~/stage FROM table_name FILE_FORMAT = {0}".format(expected), - (), - ) + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.unload( + "@~/stage", + "table_name", + file_type=file_type, + format_options=format_options, + header=header, + copy_options=copy_options, + ) + sf.conn.cursor.return_value.execute.assert_called_with( + f"COPY INTO @~/stage FROM table_name FILE_FORMAT = {expected}", + (), + ) @mock.patch("locopy.s3.Session") def test_unload_exception(mock_session, sf_credentials): - with mock.patch("snowflake.connector.connect") as mock_connect: - with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf: - - with pytest.raises(ValueError): - sf.unload("table_name", "@~/stage", file_type="unknown") - - # exception - sf.conn.cursor.return_value.execute.side_effect = Exception( - "UNLOAD Exception" - ) - with pytest.raises(DBError): - sf.unload("@~/stage", "table_name") + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + with pytest.raises(ValueError): + sf.unload("table_name", "@~/stage", file_type="unknown") + + # exception + sf.conn.cursor.return_value.execute.side_effect = Exception("UNLOAD Exception") + with pytest.raises(DBError): + sf.unload("@~/stage", "table_name") - sf.conn = None - with pytest.raises(DBError): - sf.unload("@~/stage", "table_name") + sf.conn = None + with pytest.raises(DBError): + sf.unload("@~/stage", "table_name") @mock.patch("locopy.s3.Session") @@ -401,18 +406,20 @@ def test_to_pandas(mock_session, sf_credentials): import pandas as pd test_df = pd.read_csv(os.path.join(CURR_DIR, "data", "mock_dataframe.txt"), sep=",") - with mock.patch("snowflake.connector.connect") as mock_connect: - with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf: - sf.cursor._query_result_format = "arrow" - sf.to_dataframe() - sf.conn.cursor.return_value.fetch_pandas_all.assert_called_with() + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.cursor._query_result_format = "arrow" + sf.to_dataframe() + sf.conn.cursor.return_value.fetch_pandas_all.assert_called_with() - sf.cursor._query_result_format = "json" - sf.to_dataframe() - sf.conn.cursor.return_value.fetchall.assert_called_with() + sf.cursor._query_result_format = "json" + sf.to_dataframe() + sf.conn.cursor.return_value.fetchall.assert_called_with() - sf.to_dataframe(5) - sf.conn.cursor.return_value.fetchmany.assert_called_with(5) + sf.to_dataframe(5) + sf.conn.cursor.return_value.fetchmany.assert_called_with(5) @mock.patch("locopy.s3.Session") @@ -420,61 +427,63 @@ def test_insert_dataframe_to_table(mock_session, sf_credentials): import pandas as pd test_df = pd.read_csv(os.path.join(CURR_DIR, "data", "mock_dataframe.txt"), sep=",") - with mock.patch("snowflake.connector.connect") as mock_connect: - with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf: - sf.insert_dataframe_to_table(test_df, "database.schema.test") - sf.conn.cursor.return_value.executemany.assert_called_with( - "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)", - [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], - ) + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.insert_dataframe_to_table(test_df, "database.schema.test") + sf.conn.cursor.return_value.executemany.assert_called_with( + "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)", + [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], + ) - sf.insert_dataframe_to_table(test_df, "database.schema.test", create=True) - sf.conn.cursor.return_value.execute.assert_any_call( - "CREATE TABLE database.schema.test (a int,b varchar,c date)", () - ) - sf.conn.cursor.return_value.executemany.assert_called_with( - "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)", - [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], - ) + sf.insert_dataframe_to_table(test_df, "database.schema.test", create=True) + sf.conn.cursor.return_value.execute.assert_any_call( + "CREATE TABLE database.schema.test (a int,b varchar,c date)", () + ) + sf.conn.cursor.return_value.executemany.assert_called_with( + "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)", + [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], + ) - sf.insert_dataframe_to_table( - test_df, "database.schema.test", columns=["a", "b"] - ) + sf.insert_dataframe_to_table( + test_df, "database.schema.test", columns=["a", "b"] + ) - sf.conn.cursor.return_value.executemany.assert_called_with( - "INSERT INTO database.schema.test (a,b) VALUES (%s,%s)", - [("1", "x"), ("2", "y")], - ) + sf.conn.cursor.return_value.executemany.assert_called_with( + "INSERT INTO database.schema.test (a,b) VALUES (%s,%s)", + [("1", "x"), ("2", "y")], + ) - sf.insert_dataframe_to_table( - test_df, - "database.schema.test", - create=True, - metadata=OrderedDict( - [("col1", "int"), ("col2", "varchar"), ("col3", "date")] - ), - ) + sf.insert_dataframe_to_table( + test_df, + "database.schema.test", + create=True, + metadata=OrderedDict( + [("col1", "int"), ("col2", "varchar"), ("col3", "date")] + ), + ) - sf.conn.cursor.return_value.execute.assert_any_call( - "CREATE TABLE database.schema.test (col1 int,col2 varchar,col3 date)", - (), - ) - sf.conn.cursor.return_value.executemany.assert_called_with( - "INSERT INTO database.schema.test (col1,col2,col3) VALUES (%s,%s,%s)", - [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], - ) + sf.conn.cursor.return_value.execute.assert_any_call( + "CREATE TABLE database.schema.test (col1 int,col2 varchar,col3 date)", + (), + ) + sf.conn.cursor.return_value.executemany.assert_called_with( + "INSERT INTO database.schema.test (col1,col2,col3) VALUES (%s,%s,%s)", + [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], + ) - sf.insert_dataframe_to_table( - test_df, - "database.schema.test", - create=False, - metadata=OrderedDict( - [("col1", "int"), ("col2", "varchar"), ("col3", "date")] - ), - ) + sf.insert_dataframe_to_table( + test_df, + "database.schema.test", + create=False, + metadata=OrderedDict( + [("col1", "int"), ("col2", "varchar"), ("col3", "date")] + ), + ) - # mock_session.warn.assert_called_with('Metadata will not be used because create is set to False.') - sf.conn.cursor.return_value.executemany.assert_called_with( - "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)", - [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], - ) + # mock_session.warn.assert_called_with('Metadata will not be used because create is set to False.') + sf.conn.cursor.return_value.executemany.assert_called_with( + "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)", + [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], + ) diff --git a/tests/test_utility.py b/tests/test_utility.py index bd32612..ec6f8f2 100644 --- a/tests/test_utility.py +++ b/tests/test_utility.py @@ -20,23 +20,31 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import datetime import os import sys from io import StringIO from itertools import cycle from pathlib import Path from unittest import mock -import datetime - -import pytest import locopy.utility as util -from locopy.errors import (CompressionError, CredentialsError, - LocopyConcatError, LocopyIgnoreHeaderError, - LocopySplitError) -from locopy.utility import (compress_file, compress_file_list, - concatenate_files, find_column_type, - get_ignoreheader_number, split_file) +import pytest +from locopy.errors import ( + CompressionError, + CredentialsError, + LocopyConcatError, + LocopyIgnoreHeaderError, + LocopySplitError, +) +from locopy.utility import ( + compress_file, + compress_file_list, + concatenate_files, + find_column_type, + get_ignoreheader_number, + split_file, +) GOOD_CONFIG_YAML = """host: my.redshift.cluster.com port: 1234 @@ -56,7 +64,7 @@ def cleanup(splits): def compare_file_contents(base_file, check_files): - check_files = cycle([open(x, "rb") for x in check_files]) + check_files = cycle([open(x, "rb") for x in check_files]) # noqa: SIM115 with open(base_file, "rb") as base: for line in base: cfile = next(check_files) @@ -211,7 +219,7 @@ def test_split_file_exception(): with pytest.raises(LocopySplitError): split_file(input_file, output_file, "Test") - with mock.patch("{0}.next".format(builtin_module_name)) as mock_next: + with mock.patch(f"{builtin_module_name}.next") as mock_next: mock_next.side_effect = Exception("SomeException") with pytest.raises(LocopySplitError): @@ -229,7 +237,7 @@ def test_split_file_exception(): @mock.patch("locopy.utility.open", mock.mock_open(read_data=GOOD_CONFIG_YAML)) def test_read_config_yaml_good(): actual = util.read_config_yaml("filename.yml") - assert set(actual.keys()) == set(["host", "port", "database", "user", "password"]) + assert set(actual.keys()) == {"host", "port", "database", "user", "password"} assert actual["host"] == "my.redshift.cluster.com" assert actual["port"] == 1234 assert actual["database"] == "db" @@ -239,7 +247,7 @@ def test_read_config_yaml_good(): def test_read_config_yaml_io(): actual = util.read_config_yaml(StringIO(GOOD_CONFIG_YAML)) - assert set(actual.keys()) == set(["host", "port", "database", "user", "password"]) + assert set(actual.keys()) == {"host", "port", "database", "user", "password"} assert actual["host"] == "my.redshift.cluster.com" assert actual["port"] == 1234 assert actual["database"] == "db" @@ -258,7 +266,8 @@ def test_concatenate_files(): with mock.patch("locopy.utility.os.remove") as mock_remove: concatenate_files(inputs, output) assert mock_remove.call_count == 3 - assert [int(line.rstrip("\n")) for line in open(output)] == list(range(1, 16)) + with open(output) as f: + assert [int(line.rstrip("\n")) for line in f] == list(range(1, 16)) os.remove(output) @@ -275,7 +284,6 @@ def test_concatenate_files_exception(): def test_find_column_type(): - from decimal import Decimal import pandas as pd @@ -341,29 +349,27 @@ def test_find_column_type(): assert find_column_type(input_text, "snowflake") == output_text_snowflake assert find_column_type(input_text, "redshift") == output_text_redshift -def test_find_column_type_new(): - - from decimal import Decimal +def test_find_column_type_new(): import pandas as pd input_text = pd.DataFrame.from_dict( - { - "a": [1], - "b": [pd.Timestamp('2017-01-01T12+0')], - "c": [1.2], - "d": ["a"], - "e": [True] - } -) + { + "a": [1], + "b": [pd.Timestamp("2017-01-01T12+0")], + "c": [1.2], + "d": ["a"], + "e": [True], + } + ) input_text = input_text.astype( dtype={ - "a": pd.Int64Dtype(), - "b": pd.DatetimeTZDtype(tz=datetime.timezone.utc), - "c": pd.Float64Dtype(), - "d": pd.StringDtype(), - "e": pd.BooleanDtype() + "a": pd.Int64Dtype(), + "b": pd.DatetimeTZDtype(tz=datetime.timezone.utc), + "c": pd.Float64Dtype(), + "d": pd.StringDtype(), + "e": pd.BooleanDtype(), } ) @@ -375,7 +381,7 @@ def test_find_column_type_new(): "e": "boolean", } - output_text_redshift = { + output_text_redshift = { "a": "int", "b": "timestamp", "c": "float", @@ -387,7 +393,6 @@ def test_find_column_type_new(): assert find_column_type(input_text, "redshift") == output_text_redshift - def test_get_ignoreheader_number(): assert ( get_ignoreheader_number(