diff --git a/.github/workflows/edgetest.yml b/.github/workflows/edgetest.yml
index f699c1f..a690653 100644
--- a/.github/workflows/edgetest.yml
+++ b/.github/workflows/edgetest.yml
@@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
name: running edgetest
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
with:
ref: develop
- name: Copy files for locopy
diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml
index 82f07c3..a801a78 100644
--- a/.github/workflows/publish-docs.yml
+++ b/.github/workflows/publish-docs.yml
@@ -12,14 +12,14 @@ jobs:
build-and-deploy:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
with:
# fetch all tags so `versioneer` can properly determine current version
fetch-depth: 0
- name: Set up Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v5
with:
- python-version: '3.8'
+ python-version: '3.9'
- name: Install dependencies
run: python -m pip install .[dev]
- name: Build
diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml
index 51fbb67..d67fc8e 100644
--- a/.github/workflows/publish-package.yml
+++ b/.github/workflows/publish-package.yml
@@ -11,12 +11,12 @@ jobs:
deploy:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
with:
# fetch all tags so `versioneer` can properly determine current version
fetch-depth: 0
- name: Set up Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install dependencies
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index abf94e8..fa50759 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -10,17 +10,32 @@ on:
branches: [develop, main]
jobs:
+ lint-and-format:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python 3.9
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.9"
+ - name: Install dependencies
+ run: python -m pip install .[qa]
+ - name: Linting by ruff
+ run: ruff check
+ - name: Formatting by ruff
+ run: ruff format --check
+
build:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [3.8, 3.9, '3.10', '3.11']
+ python-version: [3.9, '3.10', '3.11']
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 680cefa..d9fa300 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,11 +1,25 @@
repos:
- - repo: https://github.com/psf/black
- rev: 22.6.0
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.5.7
hooks:
- - id: black
- types: [file, python]
- language_version: python3.9
- - repo: https://github.com/pre-commit/mirrors-isort
- rev: v5.7.0
+ # Run the linter.
+ - id: ruff
+ args: [ --fix ]
+ # Run the formatter.
+ - id: ruff-format
+ types_or: [ python, jupyter ]
+ # # Mypy: Optional static type checking
+ # # https://github.com/pre-commit/mirrors-mypy
+ # - repo: https://github.com/pre-commit/mirrors-mypy
+ # rev: v1.11.1
+ # hooks:
+ # - id: mypy
+ # exclude: ^(docs|tests)\/
+ # language_version: python3.9
+ # args: [--namespace-packages, --explicit-package-bases, --ignore-missing-imports, --non-interactive, --install-types]
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.6.0
hooks:
- - id: isort
+ - id: trailing-whitespace
+ - id: debug-statements
+ - id: end-of-file-fixer
diff --git a/CODEOWNERS b/CODEOWNERS
index 53ed855..a61d6a5 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1 +1 @@
-* @fdosani @ak-gupta @jdawang @gladysteh99 @NikhilJArora
+* @fdosani @ak-gupta @jdawang @gladysteh99
diff --git a/README.rst b/README.rst
index c45e50a..1e73b24 100644
--- a/README.rst
+++ b/README.rst
@@ -13,7 +13,7 @@ A Python library to assist with ETL processing for:
In addition:
-- The library supports Python 3.8 to 3.10
+- The library supports Python 3.9 to 3.11
- DB Driver (Adapter) agnostic. Use your favourite driver that complies with
`DB-API 2.0 `_
- It provides functionality to download and upload data to S3 buckets, and internal stages (Snowflake)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 6e9f150..4315309 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# SPDX-Copyright: Copyright (c) Capital One Services, LLC
# SPDX-License-Identifier: Apache-2.0
# Copyright 2018 Capital One Services, LLC
@@ -27,8 +26,9 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
-import sys
import os
+import sys
+
import sphinx_rtd_theme
from locopy._version import __version__
@@ -151,7 +151,13 @@
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
- (master_doc, "locopy.tex", "locopy Documentation", "Faisal Dosani, Ian Robertson", "manual")
+ (
+ master_doc,
+ "locopy.tex",
+ "locopy Documentation",
+ "Faisal Dosani, Ian Robertson",
+ "manual",
+ )
]
@@ -185,4 +191,9 @@
intersphinx_mapping = {"https://docs.python.org/": None}
# autodoc
-autodoc_default_flags = ["members", "undoc-members", "show-inheritance", "inherited-members"]
+autodoc_default_flags = [
+ "members",
+ "undoc-members",
+ "show-inheritance",
+ "inherited-members",
+]
diff --git a/locopy/__init__.py b/locopy/__init__.py
index 091f07d..cf7c376 100644
--- a/locopy/__init__.py
+++ b/locopy/__init__.py
@@ -13,8 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""A Python library to assist with ETL processing."""
-from .database import Database
-from .redshift import Redshift
-from .s3 import S3
-from .snowflake import Snowflake
+from locopy.database import Database
+from locopy.redshift import Redshift
+from locopy.s3 import S3
+from locopy.snowflake import Snowflake
+
+__all__ = ["S3", "Database", "Redshift", "Snowflake"]
diff --git a/locopy/_version.py b/locopy/_version.py
index 5121c9b..c8554de 100644
--- a/locopy/_version.py
+++ b/locopy/_version.py
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "0.5.9"
+__version__ = "0.6.0"
diff --git a/locopy/database.py b/locopy/database.py
index 8564331..411d095 100644
--- a/locopy/database.py
+++ b/locopy/database.py
@@ -14,20 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Database Module
-"""
+"""Database Module."""
+
import time
-from .errors import CredentialsError, DBError
-from .logger import INFO, get_logger
-from .utility import read_config_yaml
+from locopy.errors import CredentialsError, DBError
+from locopy.logger import INFO, get_logger
+from locopy.utility import read_config_yaml
logger = get_logger(__name__, INFO)
-class Database(object):
- """This is the base class for all DBAPI 2 database connectors which will inherit this
- functionality. The ``Database`` class will manage connections and handle executing queries.
+class Database:
+ """Base class for all DBAPI 2 database connectors which will inherit this functionality.
+
+ The ``Database`` class will manage connections and handle executing queries.
Most of the functionality should work out of the box for classes which inherit minus the
abstract method for ``connect`` which may vary across databases.
@@ -72,13 +73,16 @@ def __init__(self, dbapi, config_yaml=None, **kwargs):
self.cursor = None
if config_yaml and self.connection:
- raise CredentialsError("Please provide kwargs or a YAML configuraton, not both.")
+ raise CredentialsError(
+ "Please provide kwargs or a YAML configuraton, not both."
+ )
if config_yaml:
self.connection = read_config_yaml(config_yaml)
def connect(self):
- """Creates a connection to a database by setting the values of the ``conn`` and ``cursor``
- attributes.
+ """Create a connection to a database.
+
+ Sets the values of the ``conn`` and ``cursor`` attributes.
Raises
------
@@ -90,11 +94,12 @@ def connect(self):
self.cursor = self.conn.cursor()
except Exception as e:
logger.error("Error connecting to the database. err: %s", e)
- raise DBError("Error connecting to the database.")
+ raise DBError("Error connecting to the database.") from e
def disconnect(self):
- """Terminates the connection by closing the values of the ``conn`` and ``cursor``
- attributes.
+ """Terminate the connection.
+
+ Closes the ``conn`` and ``cursor`` attributes.
Raises
------
@@ -108,7 +113,9 @@ def disconnect(self):
self.conn.close()
except Exception as e:
logger.error("Error disconnecting from the database. err: %s", e)
- raise DBError("There is a problem disconnecting from the database.")
+ raise DBError(
+ "There is a problem disconnecting from the database."
+ ) from e
else:
logger.info("No connection to close")
@@ -153,7 +160,7 @@ def execute(self, sql, commit=True, params=(), many=False, verbose=True):
self.cursor.execute(sql, params)
except Exception as e:
logger.error("Error running SQL query. err: %s", e)
- raise DBError("Error running SQL query.")
+ raise DBError("Error running SQL query.") from e
if commit:
self.conn.commit()
elapsed = time.time() - start_time
@@ -167,8 +174,9 @@ def execute(self, sql, commit=True, params=(), many=False, verbose=True):
raise DBError("Cannot execute SQL on a closed connection.")
def column_names(self):
- """Pull column names out of the cursor description. Depending on the
- DBAPI, it could return column names as bytes: ``b'column_name'``
+ """Pull column names out of the cursor description.
+
+ Depending on the DBAPI, it could return column names as bytes: ``b'column_name'``.
Returns
-------
@@ -177,12 +185,13 @@ def column_names(self):
"""
try:
return [column[0].decode().lower() for column in self.cursor.description]
- except:
+ except Exception:
return [column[0].lower() for column in self.cursor.description]
def to_dataframe(self, size=None):
- """Return a dataframe of the last query results. This imports Pandas
- in here, so that it's not needed for other use cases. This is just a
+ """Return a dataframe of the last query results.
+
+ This imports Pandas in here, so that it's not needed for other use cases. This is just a
convenience method.
Parameters
@@ -214,7 +223,7 @@ def to_dataframe(self, size=None):
return pandas.DataFrame(fetched, columns=columns)
def to_dict(self):
- """Generate dictionaries of rows
+ """Generate dictionaries of rows.
Yields
------
@@ -226,7 +235,7 @@ def to_dict(self):
yield dict(zip(columns, row))
def _is_connected(self):
- """Checks the connection and cursor class arrtribues are initalized.
+ """Check the connection and cursor class arrtributes are initalized.
Returns
-------
@@ -235,16 +244,18 @@ def _is_connected(self):
"""
try:
return self.conn is not None and self.cursor is not None
- except:
+ except Exception:
return False
def __enter__(self):
+ """Open the connection."""
logger.info("Connecting...")
self.connect()
logger.info("Connection established.")
return self
def __exit__(self, exc_type, exc, exc_tb):
+ """Close the connection."""
logger.info("Closing connection...")
self.disconnect()
logger.info("Connection closed.")
diff --git a/locopy/errors.py b/locopy/errors.py
index ed1f704..692d352 100644
--- a/locopy/errors.py
+++ b/locopy/errors.py
@@ -13,81 +13,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Custom errors for locopy."""
class LocopyError(Exception):
- """
- Baseclass for all Locopy errors.
- """
+ """Baseclass for all Locopy errors."""
class CompressionError(LocopyError):
- """
- Raised when there is an error compressing a file.
- """
+ """Raised when there is an error compressing a file."""
class LocopySplitError(LocopyError):
- """
- Raised when there is an error splitting a file.
- """
+ """Raised when there is an error splitting a file."""
class LocopyIgnoreHeaderError(LocopyError):
- """
- Raised when Multiple IGNOREHEADERS are found in copy options.
- """
+ """Raised when Multiple IGNOREHEADERS are found in copy options."""
class LocopyConcatError(LocopyError):
- """
- Raised when there is an error concatenating files.
- """
+ """Raised when there is an error concatenating files."""
class DBError(Exception):
- """
- Base class for all Database errors.
- """
+ """Base class for all Database errors."""
class CredentialsError(DBError):
- """
- Raised when the users credentials are not provided.
- """
+ """Raised when the users credentials are not provided."""
class S3Error(Exception):
- """
- Base class for all S3 errors.
- """
+ """Base class for all S3 errors."""
class S3CredentialsError(S3Error):
- """
- Raised when there is an error with AWS credentials.
- """
+ """Raised when there is an error with AWS credentials."""
class S3InitializationError(S3Error):
- """
- Raised when there is an error initializing S3 client.
- """
+ """Raised when there is an error initializing S3 client."""
class S3UploadError(S3Error):
- """
- Raised when there is an upload error to S3.
- """
+ """Raised when there is an upload error to S3."""
class S3DownloadError(S3Error):
- """
- Raised when there is an download error to S3.
- """
+ """Raised when there is an download error to S3."""
class S3DeletionError(S3Error):
- """
- Raised when there is an deletion error on S3.
- """
+ """Raised when there is an deletion error on S3."""
diff --git a/locopy/logger.py b/locopy/logger.py
index 2614124..c07beb4 100644
--- a/locopy/logger.py
+++ b/locopy/logger.py
@@ -14,35 +14,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Logging Module
-Module which setsup the basic logging infrustrcuture for the application
+"""Logging Module.
+
+Module which sets up the basic logging infrastructure for the application.
"""
+
import logging
import sys
# logger formating
BRIEF_FORMAT = "%(levelname)s %(asctime)s - %(name)s: %(message)s"
VERBOSE_FORMAT = (
- "%(levelname)s|%(asctime)s|%(name)s|%(filename)s|" "%(funcName)s|%(lineno)d: %(message)s"
+ "%(levelname)s|%(asctime)s|%(name)s|%(filename)s|"
+ "%(funcName)s|%(lineno)d: %(message)s"
)
FORMAT_TO_USE = VERBOSE_FORMAT
# logger levels
DEBUG = logging.DEBUG
INFO = logging.INFO
-WARN = logging.WARN
+WARN = logging.WARNING
ERROR = logging.ERROR
CRITICAL = logging.CRITICAL
def get_logger(name=None, log_level=logging.DEBUG):
- """Sets the basic logging features for the application
+ """Set the basic logging features for the application.
+
Parameters
----------
name : str, optional
The name of the logger. Defaults to ``None``
log_level : int, optional
The logging level. Defaults to ``logging.INFO``
+
Returns
-------
logging.Logger
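
For reviewers, a small sketch of the helper touched above; note that ``WARN`` is now an alias of ``logging.WARNING``, and the exact handler/format wiring stays in ``locopy.logger``.

```python
# Sketch of get_logger usage; the format constants are defined in locopy.logger.
from locopy.logger import INFO, get_logger

logger = get_logger(__name__, INFO)
logger.info("locopy logging is configured")
```
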
diff --git a/locopy/redshift.py b/locopy/redshift.py
index eeabead..84527f4 100644
--- a/locopy/redshift.py
+++ b/locopy/redshift.py
@@ -14,26 +14,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Redshift Module
+"""Redshift Module.
+
Module to wrap a database adapter into a Redshift class which can be used to connect
to Redshift, and run arbitrary code.
"""
+
import os
from pathlib import Path
-from .database import Database
-from .errors import DBError, S3CredentialsError
-from .logger import INFO, get_logger
-from .s3 import S3
-from .utility import (compress_file_list, concatenate_files, find_column_type,
- get_ignoreheader_number, split_file, write_file)
+from locopy.database import Database
+from locopy.errors import DBError, S3CredentialsError
+from locopy.logger import INFO, get_logger
+from locopy.s3 import S3
+from locopy.utility import (
+ compress_file_list,
+ concatenate_files,
+ find_column_type,
+ get_ignoreheader_number,
+ split_file,
+ write_file,
+)
logger = get_logger(__name__, INFO)
def add_default_copy_options(copy_options=None):
- """Adds in default options for the ``COPY`` job, unless those specific
- options have been provided in the request.
+ """Add in default options for the ``COPY`` job.
+
+ Unless those specific options have been provided in the request.
Parameters
----------
@@ -58,8 +67,9 @@ def add_default_copy_options(copy_options=None):
def combine_copy_options(copy_options):
- """Returns the ``copy_options`` attribute with spaces in between and as
- a string.
+ """Return the ``copy_options`` attribute with spaces in between.
+
+ Converts to a string.
Parameters
----------
@@ -76,8 +86,9 @@ def combine_copy_options(copy_options):
class Redshift(S3, Database):
- """Locopy class which manages connections to Redshift. Inherits ``Database`` and implements the
- specific ``COPY`` and ``UNLOAD`` functionality.
+ """Locopy class which manages connections to Redshift.
+
+ Inherits ``Database`` and implements the specific ``COPY`` and ``UNLOAD`` functionality.
If any of host, port, dbname, user and password are not provided, a config_yaml file must be
provided with those parameters in it. Please note ssl is always enforced when connecting.
@@ -159,8 +170,9 @@ def __init__(
Database.__init__(self, dbapi, config_yaml, **kwargs)
def connect(self):
- """Creates a connection to the Redshift cluster by
- setting the values of the ``conn`` and ``cursor`` attributes.
+ """Create a connection to the Redshift cluster.
+
+ Sets the values of the ``conn`` and ``cursor`` attributes.
Raises
------
@@ -171,11 +183,10 @@ def connect(self):
self.connection["sslmode"] = "require"
elif self.dbapi.__name__ == "pg8000":
self.connection["ssl_context"] = True
- super(Redshift, self).connect()
+ super().connect()
def copy(self, table_name, s3path, delim="|", copy_options=None):
- """Executes the COPY command to load files from S3 into
- a Redshift table.
+ """Execute the COPY command to load files from S3 into a Redshift table.
Parameters
----------
@@ -200,10 +211,10 @@ def copy(self, table_name, s3path, delim="|", copy_options=None):
"""
if not self._is_connected():
raise DBError("No Redshift connection object is present.")
- if copy_options and "PARQUET" not in copy_options or copy_options is None:
+ if (copy_options and "PARQUET" not in copy_options) or copy_options is None:
copy_options = add_default_copy_options(copy_options)
if delim:
- copy_options = [f"DELIMITER '{delim}'"] + copy_options
+ copy_options = [f"DELIMITER '{delim}'", *copy_options]
copy_options_text = combine_copy_options(copy_options)
base_copy_string = "COPY {0} FROM '{1}' " "CREDENTIALS '{2}' " "{3};"
try:
@@ -214,7 +225,7 @@ def copy(self, table_name, s3path, delim="|", copy_options=None):
except Exception as e:
logger.error("Error running COPY on Redshift. err: %s", e)
- raise DBError("Error running COPY on Redshift.")
+ raise DBError("Error running COPY on Redshift.") from e
def load_and_copy(
self,
@@ -228,8 +239,9 @@ def load_and_copy(
compress=True,
s3_folder=None,
):
- """Loads a file to S3, then copies into Redshift. Has options to
- split a single file into multiple files, compress using gzip, and
+ r"""Load a file to S3, then copies into Redshift.
+
+ Has options to split a single file into multiple files, compress using gzip, and
upload to an S3 bucket with folders within the bucket.
Notes
@@ -341,8 +353,9 @@ def unload_and_copy(
parallel_off=False,
unload_options=None,
):
- """``UNLOAD`` data from Redshift, with options to write to a flat file,
- and store on S3.
+ """Unload data from Redshift.
+
+ With options to write to a flat file and store on S3.
Parameters
----------
@@ -390,27 +403,27 @@ def unload_and_copy(
# data = []
s3path = self._generate_unload_path(s3_bucket, s3_folder)
- ## configure unload options
+ # configure unload options
if unload_options is None:
unload_options = []
if delim:
- unload_options.append("DELIMITER '{0}'".format(delim))
+ unload_options.append(f"DELIMITER '{delim}'")
if parallel_off:
unload_options.append("PARALLEL OFF")
- ## run unload
+ # run unload
self.unload(query=query, s3path=s3path, unload_options=unload_options)
- ## parse unloaded files
+ # parse unloaded files
s3_download_list = self._unload_generated_files()
if s3_download_list is None:
logger.error("No files generated from unload")
- raise Exception("No files generated from unload")
+ raise DBError("No files generated from unload")
columns = self._get_column_names(query)
if columns is None:
logger.error("Unable to retrieve column names from exported data")
- raise Exception("Unable to retrieve column names from exported data.")
+ raise DBError("Unable to retrieve column names from exported data.")
# download files locally with same name
local_list = self.download_list_from_s3(s3_download_list, raw_unload_path)
@@ -423,8 +436,7 @@ def unload_and_copy(
self.delete_list_from_s3(s3_download_list)
def unload(self, query, s3path, unload_options=None):
- """Executes the UNLOAD command to export a query from
- Redshift to S3.
+ """Execute the UNLOAD command to export a query from Redshift to S3.
Parameters
----------
@@ -462,10 +474,10 @@ def unload(self, query, s3path, unload_options=None):
self.execute(sql, commit=True)
except Exception as e:
logger.error("Error running UNLOAD on redshift. err: %s", e)
- raise DBError("Error running UNLOAD on redshift.")
+ raise DBError("Error running UNLOAD on redshift.") from e
def _get_column_names(self, query):
- """Gets a list of column names from the supplied query.
+ """Get a list of column names from the supplied query.
Parameters
----------
@@ -477,22 +489,21 @@ def _get_column_names(self, query):
list
List of column names. Returns None if no columns were retrieved.
"""
-
try:
logger.info("Retrieving column names")
- sql = "SELECT * FROM ({}) WHERE 1 = 0".format(query)
+ sql = f"SELECT * FROM ({query}) WHERE 1 = 0"
self.execute(sql)
- results = [desc for desc in self.cursor.description]
+ results = list(self.cursor.description)
if len(results) > 0:
return [result[0].strip() for result in results]
else:
return None
- except Exception as e:
+ except Exception:
logger.error("Error retrieving column names")
raise
def _unload_generated_files(self):
- """Gets a list of files generated by the unload process
+ """Get a list of files generated by the unload process.
Returns
-------
@@ -511,7 +522,7 @@ def _unload_generated_files(self):
return [result[0].strip() for result in results]
else:
return None
- except Exception as e:
+ except Exception:
logger.error("Error retrieving unloads generated files")
raise
@@ -556,7 +567,6 @@ def insert_dataframe_to_table(
"""
-
import pandas as pd
if columns:
@@ -585,9 +595,7 @@ def insert_dataframe_to_table(
+ ")"
)
column_sql = "(" + ",".join(list(metadata.keys())) + ")"
- create_query = "CREATE TABLE {table_name} {create_join}".format(
- table_name=table_name, create_join=create_join
- )
+ create_query = f"CREATE TABLE {table_name} {create_join}"
self.execute(create_query)
logger.info("New table has been created")
@@ -611,9 +619,7 @@ def insert_dataframe_to_table(
to_insert.append(none_row)
string_join = ", ".join(to_insert)
insert_query = (
- """INSERT INTO {table_name} {columns} VALUES {values}""".format(
- table_name=table_name, columns=column_sql, values=string_join
- )
+ f"""INSERT INTO {table_name} {column_sql} VALUES {string_join}"""
)
self.execute(insert_query, verbose=verbose)
logger.info("Table insertion has completed")
diff --git a/locopy/s3.py b/locopy/s3.py
index 348cf4b..2229118 100644
--- a/locopy/s3.py
+++ b/locopy/s3.py
@@ -14,17 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""S3 Module
+"""S3 Module.
+
Module to wrap the boto3 api usage and provide functionality to manage
-multipart upload to S3 buckets
+multipart upload to S3 buckets.
"""
+
import os
from boto3 import Session
from boto3.s3.transfer import TransferConfig
from botocore.client import Config
-from .errors import (
+from locopy.errors import (
S3CredentialsError,
S3DeletionError,
S3DownloadError,
@@ -32,16 +34,16 @@
S3InitializationError,
S3UploadError,
)
-from .logger import INFO, get_logger
-from .utility import ProgressPercentage
+from locopy.logger import INFO, get_logger
+from locopy.utility import ProgressPercentage
logger = get_logger(__name__, INFO)
-class S3(object):
- """
- S3 wrapper class which utilizes the boto3 library to push files to an S3
- bucket.
+class S3:
+ """S3 wrapper class.
+
+ Utilizes the boto3 library to push files to an S3 bucket.
Parameters
----------
@@ -85,7 +87,6 @@ class S3(object):
"""
def __init__(self, profile=None, kms_key=None, **kwargs):
-
self.profile = profile
self.kms_key = kms_key
self.session = None
@@ -100,7 +101,7 @@ def _set_session(self):
logger.info("Initialized AWS session.")
except Exception as e:
logger.error("Error initializing AWS Session, err: %s", e)
- raise S3Error("Error initializing AWS Session.")
+ raise S3Error("Error initializing AWS Session.") from e
credentials = self.session.get_credentials()
if credentials is None:
raise S3CredentialsError("Credentials could not be set.")
@@ -111,11 +112,12 @@ def _set_client(self):
logger.info("Successfully initialized S3 client.")
except Exception as e:
logger.error("Error initializing S3 Client, err: %s", e)
- raise S3InitializationError("Error initializing S3 Client.")
+ raise S3InitializationError("Error initializing S3 Client.") from e
def _credentials_string(self):
- """Returns a credentials string for the Redshift COPY or UNLOAD command,
- containing credentials from the current session.
+ """Return a credentials string for the Redshift COPY or UNLOAD command.
+
+ Contains credentials from the current session.
Returns
-------
@@ -131,7 +133,7 @@ def _credentials_string(self):
return temp.format(creds.access_key, creds.secret_key)
def _generate_s3_path(self, bucket, key):
- """Will return the S3 file URL in the format S3://bucket/key
+ """Will return the S3 file URL in the format S3://bucket/key.
Parameters
----------
@@ -146,11 +148,13 @@ def _generate_s3_path(self, bucket, key):
str
string of the S3 file URL in the format S3://bucket/key
"""
- return "s3://{0}/{1}".format(bucket, key)
+ return f"s3://{bucket}/{key}"
def _generate_unload_path(self, bucket, folder):
- """Will return the S3 file URL in the format s3://bucket/folder if a
- valid (not None) folder is provided. Otherwise, returns s3://bucket
+ """Return the S3 file URL.
+
+ If a valid (not None) folder is provided, returns the URL in the format s3://bucket/folder.
+ Otherwise, returns s3://bucket.
Parameters
----------
@@ -168,9 +172,9 @@ def _generate_unload_path(self, bucket, folder):
If folder is None, returns format s3://bucket
"""
if folder:
- s3_path = "s3://{0}/{1}".format(bucket, folder)
+ s3_path = f"s3://{bucket}/{folder}"
else:
- s3_path = "s3://{0}".format(bucket)
+ s3_path = f"s3://{bucket}"
return s3_path
def upload_to_s3(self, local, bucket, key):
@@ -204,13 +208,19 @@ def upload_to_s3(self, local, bucket, key):
extra_args["SSEKMSKeyId"] = self.kms_key
logger.info("Using KMS Keys for encryption")
- logger.info("Uploading file to S3 bucket: %s", self._generate_s3_path(bucket, key))
+ logger.info(
+ "Uploading file to S3 bucket: %s", self._generate_s3_path(bucket, key)
+ )
self.s3.upload_file(
- local, bucket, key, ExtraArgs=extra_args, Callback=ProgressPercentage(local)
+ local,
+ bucket,
+ key,
+ ExtraArgs=extra_args,
+ Callback=ProgressPercentage(local),
)
except Exception as e:
logger.error("Error uploading to S3. err: %s", e)
- raise S3UploadError("Error uploading to S3.")
+ raise S3UploadError("Error uploading to S3.") from e
def upload_list_to_s3(self, local_list, bucket, folder=None):
"""
@@ -275,13 +285,14 @@ def download_from_s3(self, bucket, key, local):
"""
try:
logger.info(
- "Downloading file from S3 bucket: %s", self._generate_s3_path(bucket, key),
+ "Downloading file from S3 bucket: %s",
+ self._generate_s3_path(bucket, key),
)
config = TransferConfig(max_concurrency=5)
self.s3.download_file(bucket, key, local, Config=config)
except Exception as e:
logger.error("Error downloading from S3. err: %s", e)
- raise S3DownloadError("Error downloading from S3.")
+ raise S3DownloadError("Error downloading from S3.") from e
def download_list_from_s3(self, s3_list, local_path=None):
"""
@@ -330,11 +341,13 @@ def delete_from_s3(self, bucket, key):
If there is a issue deleting from the S3 bucket
"""
try:
- logger.info("Deleting file from S3 bucket: %s", self._generate_s3_path(bucket, key))
+ logger.info(
+ "Deleting file from S3 bucket: %s", self._generate_s3_path(bucket, key)
+ )
self.s3.delete_object(Bucket=bucket, Key=key)
except Exception as e:
logger.error("Error deleting from S3. err: %s", e)
- raise S3DeletionError("Error deleting from S3.")
+ raise S3DeletionError("Error deleting from S3.") from e
def delete_list_from_s3(self, s3_list):
"""
@@ -351,9 +364,7 @@ def delete_list_from_s3(self, s3_list):
self.delete_from_s3(s3_bucket, s3_key)
def parse_s3_url(self, s3_url):
- """
- Parse a string of the s3 url to extract the bucket and key.
- scheme or not.
+ """Extract the bucket and key from a s3 url.
Parameters
----------
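
A short sketch of the S3 helpers reformatted above; the bucket, key and profile names are placeholders, and the calls mirror the method signatures shown in the hunks.

```python
# Placeholder bucket/key/profile names; methods match the signatures above.
import locopy

s3 = locopy.S3(profile="example-profile")
s3.upload_to_s3("local_file.txt", "example-bucket", "folder/local_file.txt")
s3.download_from_s3("example-bucket", "folder/local_file.txt", "local_copy.txt")
s3.delete_from_s3("example-bucket", "folder/local_file.txt")
```
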
diff --git a/locopy/snowflake.py b/locopy/snowflake.py
index 60d3c5e..81fc842 100644
--- a/locopy/snowflake.py
+++ b/locopy/snowflake.py
@@ -14,18 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Snowflake Module
+"""Snowflake Module.
+
Module to wrap a database adapter into a Snowflake class which can be used to connect
to Snowflake, and run arbitrary code.
"""
+
import os
from pathlib import PurePath
-from .database import Database
-from .errors import DBError, S3CredentialsError
-from .logger import INFO, get_logger
-from .s3 import S3
-from .utility import find_column_type
+from locopy.database import Database
+from locopy.errors import DBError, S3CredentialsError
+from locopy.logger import INFO, get_logger
+from locopy.s3 import S3
+from locopy.utility import find_column_type
logger = get_logger(__name__, INFO)
@@ -85,8 +87,9 @@
def combine_options(options=None):
- """Returns the ``copy_options`` or ``format_options`` attribute with spaces in between and as
- a string. If options is ``None`` then return an empty string.
+ """Return the ``copy_options`` or ``format_options`` attribute.
+
+ With spaces in between and as a string. If options is ``None`` then return an empty string.
Parameters
----------
@@ -103,8 +106,9 @@ def combine_options(options=None):
class Snowflake(S3, Database):
- """Locopy class which manages connections to Snowflake. Inherits ``Database`` and implements
- the specific ``COPY INTO`` functionality.
+ """Locopy class which manages connections to Snowflake. Inherits ``Database``.
+
+ Implements the specific ``COPY INTO`` functionality.
Parameters
----------
@@ -183,22 +187,23 @@ def __init__(
Database.__init__(self, dbapi, config_yaml, **kwargs)
def connect(self):
- """Creates a connection to the Snowflake cluster by
- setting the values of the ``conn`` and ``cursor`` attributes.
+ """Create a connection to the Snowflake cluster.
+
+ Sets the values of the ``conn`` and ``cursor`` attributes.
Raises
------
DBError
If there is a problem establishing a connection to Snowflake.
"""
- super(Snowflake, self).connect()
+ super().connect()
if self.connection.get("warehouse") is not None:
- self.execute("USE WAREHOUSE {0}".format(self.connection["warehouse"]))
+ self.execute("USE WAREHOUSE {}".format(self.connection["warehouse"]))
if self.connection.get("database") is not None:
- self.execute("USE DATABASE {0}".format(self.connection["database"]))
+ self.execute("USE DATABASE {}".format(self.connection["database"]))
if self.connection.get("schema") is not None:
- self.execute("USE SCHEMA {0}".format(self.connection["schema"]))
+ self.execute("USE SCHEMA {}".format(self.connection["schema"]))
def upload_to_internal(
self, local, stage, parallel=4, auto_compress=True, overwrite=True
@@ -231,9 +236,7 @@ def upload_to_internal(
"""
local_uri = PurePath(local).as_posix()
self.execute(
- "PUT 'file://{0}' {1} PARALLEL={2} AUTO_COMPRESS={3} OVERWRITE={4}".format(
- local_uri, stage, parallel, auto_compress, overwrite
- )
+ f"PUT 'file://{local_uri}' {stage} PARALLEL={parallel} AUTO_COMPRESS={auto_compress} OVERWRITE={overwrite}"
)
def download_from_internal(self, stage, local=None, parallel=10):
@@ -255,16 +258,15 @@ def download_from_internal(self, stage, local=None, parallel=10):
if local is None:
local = os.getcwd()
local_uri = PurePath(local).as_posix()
- self.execute(
- "GET {0} 'file://{1}' PARALLEL={2}".format(stage, local_uri, parallel)
- )
+ self.execute(f"GET {stage} 'file://{local_uri}' PARALLEL={parallel}")
def copy(
self, table_name, stage, file_type="csv", format_options=None, copy_options=None
):
- """Executes the ``COPY INTO
`` command to load CSV files from a stage into
- a Snowflake table. If ``file_type == csv`` and ``format_options == None``, ``format_options``
- will default to: ``["FIELD_DELIMITER='|'", "SKIP_HEADER=0"]``
+ """Load files from a stage into a Snowflake table.
+
+ Execute the ``COPY INTO `` command. If ``file_type == csv`` and ``format_options == None``, ``format_options``
+ will default to: ``["FIELD_DELIMITER='|'", "SKIP_HEADER=0"]``.
Parameters
----------
@@ -296,9 +298,7 @@ def copy(
if file_type not in COPY_FORMAT_OPTIONS:
raise ValueError(
- "Invalid file_type. Must be one of {0}".format(
- list(COPY_FORMAT_OPTIONS.keys())
- )
+ f"Invalid file_type. Must be one of {list(COPY_FORMAT_OPTIONS.keys())}"
)
if format_options is None and file_type == "csv":
@@ -317,7 +317,7 @@ def copy(
except Exception as e:
logger.error("Error running COPY on Snowflake. err: %s", e)
- raise DBError("Error running COPY on Snowflake.")
+ raise DBError("Error running COPY on Snowflake.") from e
def unload(
self,
@@ -328,9 +328,12 @@ def unload(
header=False,
copy_options=None,
):
- """Executes the ``COPY INTO `` command to export a query/table from
- Snowflake to a stage. If ``file_type == csv`` and ``format_options == None``, ``format_options``
- will default to: ``["FIELD_DELIMITER='|'"]``
+ """Export a query/table from Snowflake to a stage.
+
+ Execute the ``COPY INTO `` command.
+
+ If ``file_type == csv`` and ``format_options == None``, ``format_options``
+ will default to: ``["FIELD_DELIMITER='|'"]``.
Parameters
----------
@@ -364,9 +367,7 @@ def unload(
if file_type not in COPY_FORMAT_OPTIONS:
raise ValueError(
- "Invalid file_type. Must be one of {0}".format(
- list(UNLOAD_FORMAT_OPTIONS.keys())
- )
+ f"Invalid file_type. Must be one of {list(UNLOAD_FORMAT_OPTIONS.keys())}"
)
if format_options is None and file_type == "csv":
@@ -390,13 +391,14 @@ def unload(
self.execute(sql, commit=True)
except Exception as e:
logger.error("Error running UNLOAD on Snowflake. err: %s", e)
- raise DBError("Error running UNLOAD on Snowflake.")
+ raise DBError("Error running UNLOAD on Snowflake.") from e
def insert_dataframe_to_table(
self, dataframe, table_name, columns=None, create=False, metadata=None
):
- """
- Insert a Pandas dataframe to an existing table or a new table. In newer versions of the
+ """Insert a Pandas dataframe to an existing table or a new table.
+
+ In newer versions of the
python snowflake connector (v2.1.2+) users can call the ``write_pandas`` method from the cursor
directly, ``insert_dataframe_to_table`` is a custom implementation and does not use
``write_pandas``. Instead of using ``COPY INTO`` the method builds a list of tuples to
@@ -421,7 +423,6 @@ def insert_dataframe_to_table(
metadata: dictionary, optional
If metadata==None, it will be generated based on data
"""
-
import pandas as pd
if columns:
@@ -434,7 +435,7 @@ def insert_dataframe_to_table(
# create a list of tuples for insert
to_insert = []
for row in dataframe.itertuples(index=False):
- none_row = tuple([None if pd.isnull(val) else str(val) for val in row])
+ none_row = tuple(None if pd.isnull(val) else str(val) for val in row)
to_insert.append(none_row)
if not create and metadata:
@@ -457,22 +458,20 @@ def insert_dataframe_to_table(
+ ")"
)
column_sql = "(" + ",".join(list(metadata.keys())) + ")"
- create_query = "CREATE TABLE {table_name} {create_join}".format(
- table_name=table_name, create_join=create_join
- )
+ create_query = f"CREATE TABLE {table_name} {create_join}"
self.execute(create_query)
logger.info("New table has been created")
- insert_query = """INSERT INTO {table_name} {columns} VALUES {values}""".format(
- table_name=table_name, columns=column_sql, values=string_join
- )
+ insert_query = f"""INSERT INTO {table_name} {column_sql} VALUES {string_join}"""
logger.info("Inserting records...")
self.execute(insert_query, params=to_insert, many=True)
logger.info("Table insertion has completed")
def to_dataframe(self, size=None):
- """Return a dataframe of the last query results. This is just a convenience method. This
+ """Return a dataframe of the last query results.
+
+ This is just a convenience method. This
method overrides the base classes implementation in favour for the snowflake connectors
built-in ``fetch_pandas_all`` when ``size==None``. If ``size != None`` then we will continue
to use the existing functionality where we iterate through the cursor and build the
@@ -492,4 +491,4 @@ def to_dataframe(self, size=None):
if size is None and self.cursor._query_result_format == "arrow":
return self.cursor.fetch_pandas_all()
else:
- return super(Snowflake, self).to_dataframe(size)
+ return super().to_dataframe(size)
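
The stage-related methods above (``upload_to_internal``, ``copy``, ``to_dataframe``) are easier to review with a usage sketch in mind; the config file, stage and table names below are placeholders.

```python
# Placeholder config/stage/table names; the calls mirror the signatures above.
import snowflake.connector

import locopy

with locopy.Snowflake(dbapi=snowflake.connector, config_yaml="sf_example.yaml") as sf:
    sf.upload_to_internal("local_file.txt", "@~/staged/")  # PUT to an internal stage
    sf.copy("my_table", "@~/staged/local_file.txt.gz", file_type="csv")
    sf.execute("SELECT * FROM my_table")
    df = sf.to_dataframe()  # uses fetch_pandas_all when the result format is arrow
```
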
diff --git a/locopy/utility.py b/locopy/utility.py
index a903b41..67e8253 100644
--- a/locopy/utility.py
+++ b/locopy/utility.py
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Utility Module
-Module which utility functions for use within the application
+"""Utility Module.
+
+Module with utility functions for use within the application.
"""
+
import gzip
import os
import shutil
@@ -27,9 +29,14 @@
import yaml
-from .errors import (CompressionError, CredentialsError, LocopyConcatError,
- LocopyIgnoreHeaderError, LocopySplitError)
-from .logger import INFO, get_logger
+from locopy.errors import (
+ CompressionError,
+ CredentialsError,
+ LocopyConcatError,
+ LocopyIgnoreHeaderError,
+ LocopySplitError,
+)
+from locopy.logger import INFO, get_logger
logger = get_logger(__name__, INFO)
@@ -63,7 +70,7 @@ def write_file(data, delimiter, filepath, mode="w"):
def compress_file(input_file, output_file):
- """Compresses a file (gzip)
+ """Compresses a file (gzip).
Parameters
----------
@@ -73,17 +80,16 @@ def compress_file(input_file, output_file):
Path to write the compressed file
"""
try:
- with open(input_file, "rb") as f_in:
- with gzip.open(output_file, "wb") as f_out:
- logger.info("compressing (gzip): %s to %s", input_file, output_file)
- shutil.copyfileobj(f_in, f_out)
+ with open(input_file, "rb") as f_in, gzip.open(output_file, "wb") as f_out:
+ logger.info("compressing (gzip): %s to %s", input_file, output_file)
+ shutil.copyfileobj(f_in, f_out)
except Exception as e:
logger.error("Error compressing the file. err: %s", e)
- raise CompressionError("Error compressing the file.")
+ raise CompressionError("Error compressing the file.") from e
def compress_file_list(file_list):
- """Compresses a list of files (gzip) and clean up the old files
+ """Compresses a list of files (gzip) and clean up the old files.
Parameters
----------
@@ -97,7 +103,7 @@ def compress_file_list(file_list):
gz appended)
"""
for i, f in enumerate(file_list):
- gz = "{0}.gz".format(f)
+ gz = f"{f}.gz"
compress_file(f, gz)
file_list[i] = gz
os.remove(f) # cleanup old files
@@ -149,7 +155,7 @@ def split_file(input_file, output_file, splits=1, ignore_header=0):
cpool = cycle(pool)
logger.info("splitting file: %s into %s files", input_file, splits)
# open output file handlers
- files = [open("{0}.{1}".format(output_file, x), "wb") for x in pool]
+ files = [open(f"{output_file}.{x}", "wb") for x in pool] # noqa: SIM115
# open input file and send line to different handler
with open(input_file, "rb") as f_in:
# if we have a value in ignore_header then skip those many lines to start
@@ -168,7 +174,7 @@ def split_file(input_file, output_file, splits=1, ignore_header=0):
for x in pool:
files[x].close()
os.remove(files[x].name)
- raise LocopySplitError("Error splitting the file.")
+ raise LocopySplitError("Error splitting the file.") from e
def concatenate_files(input_list, output_file, remove=True):
@@ -202,13 +208,15 @@ def concatenate_files(input_list, output_file, remove=True):
os.remove(f)
except Exception as e:
logger.error("Error concateneating files. err: %s", e)
- raise LocopyConcatError("Error concateneating files.")
+ raise LocopyConcatError("Error concatenating files.") from e
def read_config_yaml(config_yaml):
- """
- Reads a configuration YAML file to populate the database
- connection attributes, and validate required ones. Example::
+ """Read a configuration YAML file.
+
+ Populate the database connection attributes, and validate required ones.
+
+ Example::
host: my.redshift.cluster.com
port: 5439
@@ -240,7 +248,7 @@ def read_config_yaml(config_yaml):
locopy_yaml = yaml.safe_load(config_yaml)
except Exception as e:
logger.error("Error reading yaml. err: %s", e)
- raise CredentialsError("Error reading yaml.")
+ raise CredentialsError("Error reading yaml.") from e
return locopy_yaml
@@ -275,7 +283,6 @@ def find_column_type(dataframe, warehouse_type: str):
A dictionary of columns with their data type
"""
import re
- from datetime import date, datetime
import pandas as pd
@@ -313,7 +320,9 @@ def validate_float_object(column):
data = dataframe[column].dropna().reset_index(drop=True)
if data.size == 0:
column_type.append("varchar")
- elif (data.dtype in ["datetime64[ns]", "M8[ns]"]) or (re.match("(datetime64\[ns\,\W)([a-zA-Z]+)(\])",str(data.dtype))):
+ elif (data.dtype in ["datetime64[ns]", "M8[ns]"]) or (
+ re.match(r"(datetime64\[ns\,\W)([a-zA-Z]+)(\])", str(data.dtype))
+ ):
column_type.append("timestamp")
elif str(data.dtype).lower().startswith("bool"):
column_type.append("boolean")
@@ -333,20 +342,22 @@ def validate_float_object(column):
return OrderedDict(zip(list(dataframe.columns), column_type))
-class ProgressPercentage(object):
- """
- ProgressPercentage class is used by the S3Transfer upload_file callback
+class ProgressPercentage:
+ """ProgressPercentage class is used by the S3Transfer upload_file callback.
+
Please see the following url for more information:
- http://boto3.readthedocs.org/en/latest/reference/customizations/s3.html#ref-s3transfer-usage
+ http://boto3.readthedocs.org/en/latest/reference/customizations/s3.html#ref-s3transfer-usage.
"""
def __init__(self, filename):
- """
- Initiate the ProgressPercentage class, using the base information which
- makes up a pipeline
- Args:
- filename (str): A name of the file which we will monitor the
- progress of
+ """Initiate the ProgressPercentage class.
+
+ Uses the base information which makes up a pipeline.
+
+ Parameters
+ ----------
+ filename : str
+ A name of the file which we will monitor the progress of.
"""
self._filename = filename
self._size = float(os.path.getsize(filename))
@@ -354,13 +365,15 @@ def __init__(self, filename):
self._lock = threading.Lock()
def __call__(self, bytes_amount):
- # To simplify we'll assume this is hooked up
- # to a single filename.
+ """Call as a function.
+
+ To simplify, we'll assume this is hooked up to a single filename.
+ """
with self._lock:
self._seen_so_far += bytes_amount
percentage = (self._seen_so_far / self._size) * 100
sys.stdout.write(
- "\rTransfering [{0}] {1:.2f}%".format(
+ "\rTransfering [{}] {:.2f}%".format(
"#" * int(percentage / 10), percentage
)
)
@@ -368,9 +381,9 @@ def __call__(self, bytes_amount):
def get_ignoreheader_number(options):
- """
- Return the ``number_rows`` from ``IGNOREHEADER [ AS ] number_rows`` This doesn't not validate
- that the ``AS`` is valid.
+ """Return the ``number_rows`` from ``IGNOREHEADER [ AS ] number_rows``.
+
+ This doesn't validate that the ``AS`` is valid.
Parameters
----------
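
The datetime regex fix above affects what ``find_column_type`` reports for timezone-aware columns. A small sketch follows; the exact type strings returned depend on the ``warehouse_type`` argument.

```python
# Sketch only: the exact strings returned ("int"/"bigint", "float"/"float8", ...)
# depend on warehouse_type; tz-aware datetime dtypes are reported as timestamp.
import pandas as pd

from locopy.utility import find_column_type

df = pd.DataFrame(
    {
        "id": [1, 2],
        "name": ["a", "b"],
        "created": pd.to_datetime(["2020-01-01", "2020-01-02"]).tz_localize("UTC"),
    }
)
print(find_column_type(df, "snowflake"))  # OrderedDict of column -> type string
```
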
diff --git a/pyproject.toml b/pyproject.toml
index ef3555f..ffefeeb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,9 +6,9 @@ authors = [
{ name="Faisal Dosani", email="faisal.dosani@capitalone.com" },
]
license = {text = "Apache Software License"}
-dependencies = ["boto3<=1.34.126,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.2,>=0.25.2", "numpy<=1.26.4,>=1.22.0"]
+dependencies = ["boto3<=1.34.157,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.2,>=0.25.2", "numpy<=2.0.1,>=1.22.0"]
-requires-python = ">=3.8.0"
+requires-python = ">=3.9.0"
classifiers = [
"Intended Audience :: Developers",
"Natural Language :: English",
@@ -16,7 +16,6 @@ classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
@@ -41,17 +40,68 @@ pg8000 = ["pg8000>=1.13.1"]
snowflake = ["snowflake-connector-python[pandas]>=2.1.2"]
docs = ["sphinx", "sphinx_rtd_theme"]
tests = ["hypothesis", "pytest", "pytest-cov"]
-qa = ["pre-commit", "black", "isort"]
+qa = ["pre-commit", "ruff==0.5.7"]
build = ["build", "twine", "wheel"]
edgetest = ["edgetest", "edgetest-conda"]
dev = ["locopy[tests]", "locopy[docs]", "locopy[qa]", "locopy[build]"]
-[isort]
-multi_line_output = 3
-include_trailing_comma = true
-force_grid_wrap = 0
-use_parentheses = true
-line_length = 88
+# Linters, formatters and type checkers
+[tool.ruff]
+extend-include = ["*.ipynb"]
+target-version = "py39"
+src = ["src"]
+
+
+[tool.ruff.lint]
+preview = true
+select = [
+ "E", # pycodestyle errors
+ "W", # pycodestyle warnings
+ "F", # pyflakes
+ "D", # pydocstyle
+ "I", # isort
+ "UP", # pyupgrade
+ "B", # flake8-bugbear
+ # "A", # flake8-builtins
+ "C4", # flake8-comprehensions
+ #"C901", # mccabe complexity
+ # "G", # flake8-logging-format
+ "T20", # flake8-print
+ "TID252", # flake8-tidy-imports ban relative imports
+ # "ARG", # flake8-unused-arguments
+ "SIM", # flake8-simplify
+ "NPY", # numpy rules
+ "LOG", # flake8-logging
+ "RUF", # Ruff errors
+]
+
+
+ignore = [
+ "E111", # Check indentation level. Using formatter instead.
+ "E114", # Check indentation level. Using formatter instead.
+ "E117", # Check indentation level. Using formatter instead.
+ "E203", # Check whitespace. Using formatter instead.
+ "E501", # Line too long. Using formatter instead.
+ "D206", # Docstring indentation. Using formatter instead.
+ "D300", # Use triple single quotes. Using formatter instead.
+ "SIM108", # Use ternary operator instead of if-else blocks.
+ "SIM105", # Use `contextlib.suppress(FileNotFoundError)` instead of `try`-`except`-`pass`
+ "UP035", # `typing.x` is deprecated, use `x` instead
+ "UP006", # `typing.x` is deprecated, use `x` instead
+]
+
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["E402"]
+"**/{tests,docs}/*" = ["E402", "D", "F841", "ARG"]
+
+
+[tool.ruff.lint.flake8-tidy-imports]
+ban-relative-imports = "all"
+
+
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
[edgetest.envs.core]
python_version = "3.9"
diff --git a/tests/test_database.py b/tests/test_database.py
index 3366759..205e605 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -27,7 +27,6 @@
import psycopg2
import pytest
import snowflake.connector
-
from locopy import Database
from locopy.errors import CredentialsError, DBError
@@ -57,7 +56,12 @@ def test_database_constructor(credentials, dbapi):
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_database_constructor_kwargs(dbapi):
d = Database(
- dbapi=dbapi, host="host", port="port", database="database", user="user", password="password"
+ dbapi=dbapi,
+ host="host",
+ port="port",
+ database="database",
+ user="user",
+ password="password",
)
assert d.connection["host"] == "host"
assert d.connection["port"] == "port"
@@ -132,7 +136,11 @@ def test_connect(credentials, dbapi):
b = Database(dbapi=dbapi, **credentials)
b.connect()
mock_connect.assert_called_with(
- host="host", user="user", port="port", password="password", database="database"
+ host="host",
+ user="user",
+ port="port",
+ password="password",
+ database="database",
)
credentials["extra"] = 123
@@ -182,11 +190,12 @@ def test_disconnect_no_conn(credentials, dbapi):
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_execute(credentials, dbapi):
- with mock.patch(dbapi.__name__ + ".connect") as mock_connect:
- with Database(dbapi=dbapi, **credentials) as test:
- print(test)
- test.execute("SELECT * FROM some_table")
- assert test.cursor.execute.called is True
+ with (
+ mock.patch(dbapi.__name__ + ".connect") as mock_connect,
+ Database(dbapi=dbapi, **credentials) as test,
+ ):
+ test.execute("SELECT * FROM some_table")
+ assert test.cursor.execute.called is True
@pytest.mark.parametrize("dbapi", DBAPIS)
@@ -201,18 +210,24 @@ def test_execute_no_connection_exception(credentials, dbapi):
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_execute_sql_exception(credentials, dbapi):
- with mock.patch(dbapi.__name__ + ".connect") as mock_connect:
- with Database(dbapi=dbapi, **credentials) as test:
- test.cursor.execute.side_effect = Exception("SQL Exception")
- with pytest.raises(DBError):
- test.execute("SELECT * FROM some_table")
+ with (
+ mock.patch(dbapi.__name__ + ".connect") as mock_connect,
+ Database(dbapi=dbapi, **credentials) as test,
+ ):
+ test.cursor.execute.side_effect = Exception("SQL Exception")
+ with pytest.raises(DBError):
+ test.execute("SELECT * FROM some_table")
@pytest.mark.parametrize("dbapi", DBAPIS)
@mock.patch("pandas.DataFrame")
def test_to_dataframe_all(mock_pandas, credentials, dbapi):
with mock.patch(dbapi.__name__ + ".connect") as mock_connect:
- mock_connect.return_value.cursor.return_value.fetchall.return_value = [(1, 2), (2, 3), (3,)]
+ mock_connect.return_value.cursor.return_value.fetchall.return_value = [
+ (1, 2),
+ (2, 3),
+ (3,),
+ ]
with Database(dbapi=dbapi, **credentials) as test:
test.execute("SELECT 'hello world' AS fld")
df = test.to_dataframe()
@@ -256,11 +271,17 @@ def test_get_column_names(credentials, dbapi):
with Database(dbapi=dbapi, **credentials) as test:
assert test.column_names() == ["col1", "col2"]
- mock_connect.return_value.cursor.return_value.description = [("COL1",), ("COL2",)]
+ mock_connect.return_value.cursor.return_value.description = [
+ ("COL1",),
+ ("COL2",),
+ ]
with Database(dbapi=dbapi, **credentials) as test:
assert test.column_names() == ["col1", "col2"]
- mock_connect.return_value.cursor.return_value.description = (("COL1",), ("COL2",))
+ mock_connect.return_value.cursor.return_value.description = (
+ ("COL1",),
+ ("COL2",),
+ )
with Database(dbapi=dbapi, **credentials) as test:
assert test.column_names() == ["col1", "col2"]
@@ -274,8 +295,16 @@ def cols():
test.column_names = cols
test.cursor = [(1, 2), (2, 3), (3,)]
- assert list(test.to_dict()) == [{"col1": 1, "col2": 2}, {"col1": 2, "col2": 3}, {"col1": 3}]
+ assert list(test.to_dict()) == [
+ {"col1": 1, "col2": 2},
+ {"col1": 2, "col2": 3},
+ {"col1": 3},
+ ]
test.column_names = cols
test.cursor = [("a", 2), (2, 3), ("b",)]
- assert list(test.to_dict()) == [{"col1": "a", "col2": 2}, {"col1": 2, "col2": 3}, {"col1": "b"}]
+ assert list(test.to_dict()) == [
+ {"col1": "a", "col2": 2},
+ {"col1": 2, "col2": 3},
+ {"col1": "b"},
+ ]
diff --git a/tests/test_integration.py b/tests/test_integration.py
index ca8fb87..343a690 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -25,14 +25,13 @@
from pathlib import Path
import boto3
+import locopy
import numpy as np
import pandas as pd
import pg8000
import psycopg2
import pytest
-import locopy
-
DBAPIS = [pg8000, psycopg2]
INTEGRATION_CREDS = str(Path.home()) + "/.locopyrc"
S3_BUCKET = "locopy-integration-testing"
@@ -61,7 +60,6 @@ def s3_bucket():
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_redshift_execute_single_rows(dbapi):
-
expected = pd.DataFrame({"field_1": [1], "field_2": [2]})
with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as test:
test.execute("SELECT 1 AS field_1, 2 AS field_2 ")
@@ -73,7 +71,6 @@ def test_redshift_execute_single_rows(dbapi):
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_redshift_execute_multiple_rows(dbapi):
-
expected = pd.DataFrame({"field_1": [1, 2], "field_2": [1, 2]})
with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as test:
test.execute(
@@ -90,7 +87,6 @@ def test_redshift_execute_multiple_rows(dbapi):
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_s3_upload_download_file(s3_bucket, dbapi):
-
s3 = locopy.S3(**CREDS_DICT)
s3.upload_to_s3(LOCAL_FILE, S3_BUCKET, "myfile.txt")
@@ -104,7 +100,6 @@ def test_s3_upload_download_file(s3_bucket, dbapi):
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_copy(s3_bucket, dbapi):
-
with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as redshift:
redshift.execute(
"CREATE TEMPORARY TABLE locopy_integration_testing (id INTEGER, variable VARCHAR(20)) DISTKEY(variable)"
@@ -135,7 +130,6 @@ def test_copy(s3_bucket, dbapi):
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_copy_split_ignore(s3_bucket, dbapi):
-
with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as redshift:
redshift.execute(
"CREATE TEMPORARY TABLE locopy_integration_testing (id INTEGER, variable VARCHAR(20)) DISTKEY(variable)"
@@ -163,13 +157,12 @@ def test_copy_split_ignore(s3_bucket, dbapi):
for i, result in enumerate(results):
assert result[0] == expected[i][0]
assert result[1] == expected[i][1]
- os.remove(LOCAL_FILE_HEADER + ".{0}".format(i))
+ os.remove(LOCAL_FILE_HEADER + f".{i}")
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_unload(s3_bucket, dbapi):
-
with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as redshift:
redshift.execute(
"CREATE TEMPORARY TABLE locopy_integration_testing AS SELECT ('2017-12-31'::date + row_number() over (order by 1))::date from SVV_TABLES LIMIT 5"
@@ -231,7 +224,6 @@ def test_unload_raw_unload_path(s3_bucket, dbapi):
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_insert_dataframe_to_table(s3_bucket, dbapi):
-
with locopy.Redshift(dbapi=dbapi, **CREDS_DICT) as redshift:
redshift.insert_dataframe_to_table(TEST_DF, "locopy_df_test", create=True)
redshift.execute("SELECT a, b, c FROM locopy_df_test ORDER BY a ASC")
diff --git a/tests/test_integration_sf.py b/tests/test_integration_sf.py
index 60f9bcb..f5c7ab4 100644
--- a/tests/test_integration_sf.py
+++ b/tests/test_integration_sf.py
@@ -26,13 +26,12 @@
from pathlib import Path
import boto3
+import locopy
import numpy as np
import pandas as pd
import pytest
import snowflake.connector
-import locopy
-
DBAPIS = [snowflake.connector]
INTEGRATION_CREDS = str(Path.home()) + os.sep + ".locopy-sfrc"
S3_BUCKET = "locopy-integration-testing"
@@ -60,7 +59,6 @@ def s3_bucket():
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_snowflake_execute_single_rows(dbapi):
-
expected = pd.DataFrame({"field_1": [1], "field_2": [2]})
with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test:
test.execute("SELECT 1 AS field_1, 2 AS field_2 ")
@@ -73,11 +71,12 @@ def test_snowflake_execute_single_rows(dbapi):
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_snowflake_execute_multiple_rows(dbapi):
-
expected = pd.DataFrame({"field_1": [1, 2], "field_2": [1, 2]})
with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test:
test.execute(
- "SELECT 1 AS field_1, 1 AS field_2 " "UNION " "SELECT 2 AS field_1, 2 AS field_2"
+ "SELECT 1 AS field_1, 1 AS field_2 "
+ "UNION "
+ "SELECT 2 AS field_1, 2 AS field_2"
)
df = test.to_dataframe()
df.columns = [c.lower() for c in df.columns]
@@ -89,7 +88,6 @@ def test_snowflake_execute_multiple_rows(dbapi):
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_upload_download_internal(dbapi):
-
with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test:
# delete if exists
test.execute("REMOVE @~/staged/mock_file_dl.txt")
@@ -114,7 +112,6 @@ def test_upload_download_internal(dbapi):
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_copy(dbapi):
-
with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test:
test.upload_to_internal(LOCAL_FILE, "@~/staged/")
test.execute("USE SCHEMA {}".format(CREDS_DICT["schema"]))
@@ -144,7 +141,6 @@ def test_copy(dbapi):
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_copy_json(dbapi):
-
with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test:
test.upload_to_internal(LOCAL_FILE_JSON, "@~/staged/")
test.execute("USE SCHEMA {}".format(CREDS_DICT["schema"]))
@@ -176,7 +172,6 @@ def test_copy_json(dbapi):
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_to_dataframe(dbapi):
-
with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test:
test.upload_to_internal(LOCAL_FILE_JSON, "@~/staged/")
test.execute("USE SCHEMA {}".format(CREDS_DICT["schema"]))
@@ -197,11 +192,17 @@ def test_to_dataframe(dbapi):
result = test.to_dataframe()
result.columns = [c.lower() for c in result.columns]
expected = pd.DataFrame(
- [('"Belmont"', '"92567"'), ('"Lexington"', '"75836"'), ('"Winchester"', '"89921"'),],
+ [
+ ('"Belmont"', '"92567"'),
+ ('"Lexington"', '"75836"'),
+ ('"Winchester"', '"89921"'),
+ ],
columns=["variable:location:city", "variable:price"],
)
- assert (result["variable:location:city"] == expected["variable:location:city"]).all()
+ assert (
+ result["variable:location:city"] == expected["variable:location:city"]
+ ).all()
assert (result["variable:price"] == expected["variable:price"]).all()
# with size of 2
@@ -211,11 +212,16 @@ def test_to_dataframe(dbapi):
result = test.to_dataframe(size=2)
result.columns = [c.lower() for c in result.columns]
expected = pd.DataFrame(
- [('"Belmont"', '"92567"'), ('"Lexington"', '"75836"'),],
+ [
+ ('"Belmont"', '"92567"'),
+ ('"Lexington"', '"75836"'),
+ ],
columns=["variable:location:city", "variable:price"],
)
- assert (result["variable:location:city"] == expected["variable:location:city"]).all()
+ assert (
+ result["variable:location:city"] == expected["variable:location:city"]
+ ).all()
assert (result["variable:price"] == expected["variable:price"]).all()
# with non-select query
@@ -227,7 +233,6 @@ def test_to_dataframe(dbapi):
@pytest.mark.integration
@pytest.mark.parametrize("dbapi", DBAPIS)
def test_insert_dataframe_to_table(dbapi):
-
with locopy.Snowflake(dbapi=dbapi, **CREDS_DICT) as test:
test.insert_dataframe_to_table(TEST_DF, "test", create=True)
test.execute("SELECT a, b, c FROM test ORDER BY a ASC")
@@ -250,7 +255,15 @@ def test_insert_dataframe_to_table(dbapi):
results = test.cursor.fetchall()
test.execute("drop table if exists test_2")
- expected = [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e"), (6, "f"), (7, "g")]
+ expected = [
+ (1, "a"),
+ (2, "b"),
+ (3, "c"),
+ (4, "d"),
+ (5, "e"),
+ (6, "f"),
+ (7, "g"),
+ ]
assert len(expected) == len(results)
for i, result in enumerate(results):
@@ -280,5 +293,5 @@ def test_insert_dataframe_to_table(dbapi):
assert len(expected) == len(results)
for i, result in enumerate(results):
- for j, item in enumerate(result):
+ for j, _item in enumerate(result):
assert result[j] == expected[i][j]
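For reference, a minimal self-contained sketch of the convention the hunk above adopts: ruff's B007 check flags loop variables that are bound but never read, and the leading underscore on `_item` records that the unused binding is intentional. The data below is made up purely for illustration.

```python
# Sketch only: why an unused loop variable gets a leading underscore (ruff B007).
results = [(1, "a"), (2, "b")]
expected = [(1, "a"), (2, "b")]

for i, result in enumerate(results):
    # `_item` is bound but never read; the inner loop only needs the index `j`.
    for j, _item in enumerate(result):
        assert result[j] == expected[i][j]
```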
diff --git a/tests/test_redshift.py b/tests/test_redshift.py
index d077ddf..55a5f2f 100644
--- a/tests/test_redshift.py
+++ b/tests/test_redshift.py
@@ -24,13 +24,12 @@
from collections import OrderedDict
from unittest import mock
+import locopy
import pg8000
import psycopg2
import pytest
-
-import locopy
from locopy import Redshift
-from locopy.errors import CredentialsError, DBError
+from locopy.errors import DBError
PROFILE = "test"
GOOD_CONFIG_YAML = """
@@ -50,7 +49,9 @@ def test_add_default_copy_options():
"COMPUPDATE ON",
"TRUNCATECOLUMNS",
]
- assert locopy.redshift.add_default_copy_options(["DATEFORMAT 'other'", "NULL AS 'blah'"]) == [
+ assert locopy.redshift.add_default_copy_options(
+ ["DATEFORMAT 'other'", "NULL AS 'blah'"]
+ ) == [
"DATEFORMAT 'other'",
"NULL AS 'blah'",
"COMPUPDATE ON",
@@ -59,9 +60,9 @@ def test_add_default_copy_options():
def test_combine_copy_options():
- assert locopy.redshift.combine_copy_options(locopy.redshift.add_default_copy_options()) == (
- "DATEFORMAT 'auto' COMPUPDATE " "ON TRUNCATECOLUMNS"
- )
+ assert locopy.redshift.combine_copy_options(
+ locopy.redshift.add_default_copy_options()
+ ) == ("DATEFORMAT 'auto' COMPUPDATE " "ON TRUNCATECOLUMNS")
@pytest.mark.parametrize("dbapi", DBAPIS)
@@ -70,7 +71,7 @@ def test_constructor(mock_session, credentials, dbapi):
r = Redshift(profile=PROFILE, dbapi=dbapi, **credentials)
mock_session.assert_called_with(profile_name=PROFILE)
assert r.profile == PROFILE
- assert r.kms_key == None
+ assert r.kms_key is None
assert r.connection["host"] == "host"
assert r.connection["port"] == "port"
assert r.connection["database"] == "database"
@@ -91,7 +92,7 @@ def test_constructor_yaml(mock_session, dbapi):
r = Redshift(profile=PROFILE, dbapi=dbapi, config_yaml="some_config.yml")
mock_session.assert_called_with(profile_name=PROFILE)
assert r.profile == PROFILE
- assert r.kms_key == None
+ assert r.kms_key is None
assert r.connection["host"] == "host"
assert r.connection["port"] == "port"
assert r.connection["database"] == "database"
@@ -140,21 +141,18 @@ def test_copy_parquet(mock_execute, mock_session, credentials, dbapi):
r = Redshift(profile=PROFILE, dbapi=dbapi, **credentials)
r.connect()
r.copy("table", s3path="path", delim=None, copy_options=["PARQUET"])
- test_sql = (
- "COPY {0} FROM '{1}' "
- "CREDENTIALS '{2}' "
- "{3};".format("table", "path", r._credentials_string(), "PARQUET")
+ test_sql = "COPY {} FROM '{}' " "CREDENTIALS '{}' " "{};".format(
+ "table", "path", r._credentials_string(), "PARQUET"
)
assert mock_execute.called_with(test_sql, commit=True)
mock_execute.reset_mock()
mock_session.reset_mock()
r.copy("table", s3path="path", delim=None)
- test_sql = (
- "COPY {0} FROM '{1}' "
- "CREDENTIALS '{2}' "
- "{3};".format(
- "table", "path", r._credentials_string(), locopy.redshift.add_default_copy_options()
- )
+ test_sql = "COPY {} FROM '{}' " "CREDENTIALS '{}' " "{};".format(
+ "table",
+ "path",
+ r._credentials_string(),
+ locopy.redshift.add_default_copy_options(),
)
assert mock_execute.called_with(test_sql, commit=True)
@@ -261,7 +259,10 @@ def reset_mocks():
# mock_remove.assert_called_with("/path/local_file.2")
mock_s3_upload.assert_has_calls(expected_calls_no_folder_gzip)
mock_rs_copy.assert_called_with(
- "table_name", "s3://s3_bucket/local_file", "|", copy_options=["SOME OPTION", "GZIP"]
+ "table_name",
+ "s3://s3_bucket/local_file",
+ "|",
+ copy_options=["SOME OPTION", "GZIP"],
)
assert mock_s3_delete.called_with("s3_bucket", "local_file.0.gz")
assert mock_s3_delete.called_with("s3_bucket", "local_file.1.gz")
@@ -335,7 +336,10 @@ def reset_mocks():
"/path/local_file.txt", "s3_bucket", "test/local_file.txt"
)
mock_rs_copy.assert_called_with(
- "table_name", "s3://s3_bucket/test/local_file", "|", copy_options=["SOME OPTION"]
+ "table_name",
+ "s3://s3_bucket/test/local_file",
+ "|",
+ copy_options=["SOME OPTION"],
)
assert not mock_s3_delete.called
@@ -366,7 +370,10 @@ def reset_mocks():
# assert not mock_remove.called
mock_s3_upload.assert_has_calls(expected_calls_folder)
mock_rs_copy.assert_called_with(
- "table_name", "s3://s3_bucket/test/local_file", "|", copy_options=["SOME OPTION"]
+ "table_name",
+ "s3://s3_bucket/test/local_file",
+ "|",
+ copy_options=["SOME OPTION"],
)
assert mock_s3_delete.called_with("s3_bucket", "test/local_file.0")
assert mock_s3_delete.called_with("s3_bucket", "test/local_file.1")
@@ -467,7 +474,7 @@ def reset_mocks():
mock_split_file.return_value = ["/path/local_file.txt"]
mock_compress_file_list.return_value = ["/path/local_file.txt.gz"]
- #### neither ignore or split only
+ # neither ignore nor split only
r.load_and_copy("/path/local_file.txt", "s3_bucket", "table_name", delim="|")
# assert
@@ -482,7 +489,7 @@ def reset_mocks():
)
assert not mock_s3_delete.called, "Only delete when explicit"
- #### ignore only
+ # ignore only
reset_mocks()
r.load_and_copy(
"/path/local_file.txt",
@@ -507,7 +514,7 @@ def reset_mocks():
)
assert not mock_s3_delete.called, "Only delete when explicit"
- #### split only
+ # split only
reset_mocks()
mock_split_file.return_value = [
"/path/local_file.0",
@@ -520,7 +527,12 @@ def reset_mocks():
"/path/local_file.2.gz",
]
r.load_and_copy(
- "/path/local_file", "s3_bucket", "table_name", delim="|", splits=3, delete_s3_after=True
+ "/path/local_file",
+ "s3_bucket",
+ "table_name",
+ delim="|",
+ splits=3,
+ delete_s3_after=True,
)
# assert
@@ -538,7 +550,7 @@ def reset_mocks():
assert mock_s3_delete.called_with("s3_bucket", "local_file.1.gz")
assert mock_s3_delete.called_with("s3_bucket", "local_file.2.gz")
- #### split and ignore
+ # split and ignore
reset_mocks()
mock_split_file.return_value = [
"/path/local_file.0",
@@ -580,7 +592,6 @@ def reset_mocks():
@pytest.mark.parametrize("dbapi", DBAPIS)
@mock.patch("locopy.s3.Session")
def test_redshiftcopy(mock_session, credentials, dbapi):
-
with mock.patch(dbapi.__name__ + ".connect") as mock_connect:
r = locopy.Redshift(dbapi=dbapi, **credentials)
r.connect()
@@ -589,13 +600,9 @@ def test_redshiftcopy(mock_session, credentials, dbapi):
(
mock_connect.return_value.cursor.return_value.execute.assert_called_with(
"COPY table FROM 's3bucket' CREDENTIALS "
- "'aws_access_key_id={0};aws_secret_access_key={1};token={2}' "
+ f"'aws_access_key_id={r.session.get_credentials().access_key};aws_secret_access_key={r.session.get_credentials().secret_key};token={r.session.get_credentials().token}' "
"DELIMITER '|' DATEFORMAT 'auto' COMPUPDATE ON "
- "TRUNCATECOLUMNS;".format(
- r.session.get_credentials().access_key,
- r.session.get_credentials().secret_key,
- r.session.get_credentials().token,
- ),
+ "TRUNCATECOLUMNS;",
(),
)
)
@@ -606,13 +613,9 @@ def test_redshiftcopy(mock_session, credentials, dbapi):
(
mock_connect.return_value.cursor.return_value.execute.assert_called_with(
"COPY table FROM 's3bucket' CREDENTIALS "
- "'aws_access_key_id={0};aws_secret_access_key={1};token={2}' "
+ f"'aws_access_key_id={r.session.get_credentials().access_key};aws_secret_access_key={r.session.get_credentials().secret_key};token={r.session.get_credentials().token}' "
"DELIMITER '\t' DATEFORMAT 'auto' COMPUPDATE ON "
- "TRUNCATECOLUMNS;".format(
- r.session.get_credentials().access_key,
- r.session.get_credentials().secret_key,
- r.session.get_credentials().token,
- ),
+ "TRUNCATECOLUMNS;",
(),
)
)
@@ -622,13 +625,9 @@ def test_redshiftcopy(mock_session, credentials, dbapi):
(
mock_connect.return_value.cursor.return_value.execute.assert_called_with(
"COPY table FROM 's3bucket' CREDENTIALS "
- "'aws_access_key_id={0};aws_secret_access_key={1};token={2}' "
+ f"'aws_access_key_id={r.session.get_credentials().access_key};aws_secret_access_key={r.session.get_credentials().secret_key};token={r.session.get_credentials().token}' "
"DATEFORMAT 'auto' COMPUPDATE ON "
- "TRUNCATECOLUMNS;".format(
- r.session.get_credentials().access_key,
- r.session.get_credentials().secret_key,
- r.session.get_credentials().token,
- ),
+ "TRUNCATECOLUMNS;",
(),
)
)
@@ -638,7 +637,6 @@ def test_redshiftcopy(mock_session, credentials, dbapi):
@mock.patch("locopy.s3.Session")
@mock.patch("locopy.database.Database._is_connected")
def test_redshiftcopy_exception(mock_connected, mock_session, credentials, dbapi):
-
with mock.patch(dbapi.__name__ + ".connect") as mock_connect:
r = locopy.Redshift(dbapi=dbapi, **credentials)
mock_connected.return_value = False
@@ -691,13 +689,13 @@ def reset_mocks():
r = locopy.Redshift(dbapi=dbapi, **credentials)
##
- ## Test 1: check that basic export pipeline functions are called
+ # Test 1: check that basic export pipeline functions are called
mock_unload_generated_files.return_value = ["dummy_file"]
mock_download_list_from_s3.return_value = ["s3.file"]
mock_get_col_names.return_value = ["dummy_col_name"]
mock_generate_unload_path.return_value = "dummy_s3_path"
- ## ensure nothing is returned when read=False
+ # ensure nothing is returned when read=False
r.unload_and_copy(
query="query",
s3_bucket="s3_bucket",
@@ -710,7 +708,9 @@ def reset_mocks():
)
assert mock_unload_generated_files.called
- assert not mock_write.called, "write_file should only be called " "if export_path != False"
+ assert not mock_write.called, (
+ "write_file should only be called " "if export_path != False"
+ )
mock_generate_unload_path.assert_called_with("s3_bucket", None)
mock_get_col_names.assert_called_with("query")
mock_unload.assert_called_with(
@@ -719,7 +719,7 @@ def reset_mocks():
assert not mock_delete_list_from_s3.called
##
- ## Test 2: different delimiter
+ # Test 2: different delimiter
reset_mocks()
mock_unload_generated_files.return_value = ["dummy_file"]
mock_download_list_from_s3.return_value = ["s3.file"]
@@ -736,14 +736,16 @@ def reset_mocks():
parallel_off=True,
)
- ## check that unload options are modified based on supplied args
+ # check that unload options are modified based on supplied args
mock_unload.assert_called_with(
- query="query", s3path="dummy_s3_path", unload_options=["DELIMITER '|'", "PARALLEL OFF"]
+ query="query",
+ s3path="dummy_s3_path",
+ unload_options=["DELIMITER '|'", "PARALLEL OFF"],
)
assert not mock_delete_list_from_s3.called
##
- ## Test 2.5: delimiter is none
+ # Test 2.5: delimiter is none
reset_mocks()
mock_unload_generated_files.return_value = ["dummy_file"]
mock_download_list_from_s3.return_value = ["s3.file"]
@@ -760,32 +762,32 @@ def reset_mocks():
parallel_off=True,
)
- ## check that unload options are modified based on supplied args
+ # check that unload options are modified based on supplied args
mock_unload.assert_called_with(
query="query", s3path="dummy_s3_path", unload_options=["PARALLEL OFF"]
)
assert not mock_delete_list_from_s3.called
##
- ## Test 3: ensure exception is raised when no column names are retrieved
+ # Test 3: ensure exception is raised when no column names are retrieved
reset_mocks()
mock_unload_generated_files.return_value = ["dummy_file"]
mock_generate_unload_path.return_value = "dummy_s3_path"
mock_get_col_names.return_value = None
- with pytest.raises(Exception):
+ with pytest.raises(DBError):
r.unload_and_copy("query", "s3_bucket", None)
##
- ## Test 4: ensure exception is raised when no files are returned
+ # Test 4: ensure exception is raised when no files are returned
reset_mocks()
mock_generate_unload_path.return_value = "dummy_s3_path"
mock_get_col_names.return_value = ["dummy_col_name"]
mock_unload_generated_files.return_value = None
- with pytest.raises(Exception):
+ with pytest.raises(DBError):
r.unload_and_copy("query", "s3_bucket", None)
##
- ## Test 5: ensure file writing is initiated when export_path is supplied
+ # Test 5: ensure file writing is initiated when export_path is supplied
reset_mocks()
mock_get_col_names.return_value = ["dummy_col_name"]
mock_download_list_from_s3.return_value = ["s3.file"]
@@ -801,18 +803,20 @@ def reset_mocks():
delete_s3_after=True,
parallel_off=False,
)
- mock_concat.assert_called_with(mock_download_list_from_s3.return_value, "my_output.csv")
+ mock_concat.assert_called_with(
+ mock_download_list_from_s3.return_value, "my_output.csv"
+ )
assert mock_write.called
assert mock_delete_list_from_s3.called_with("s3_bucket", "my_output.csv")
##
- ## Test 6: raw_unload_path check
+ # Test 6: raw_unload_path check
reset_mocks()
mock_get_col_names.return_value = ["dummy_col_name"]
mock_download_list_from_s3.return_value = ["s3.file"]
mock_generate_unload_path.return_value = "dummy_s3_path"
mock_unload_generated_files.return_value = ["/dummy_file"]
- ## ensure nothing is returned when read=False
+ # ensure nothing is returned when read=False
r.unload_and_copy(
query="query",
s3_bucket="s3_bucket",
@@ -833,7 +837,7 @@ def test_unload_generated_files(mock_session, credentials, dbapi):
r = locopy.Redshift(dbapi=dbapi, **credentials)
r.connect()
r._unload_generated_files()
- assert r._unload_generated_files() == None
+ assert r._unload_generated_files() is None
mock_connect.return_value.cursor.return_value.fetchall.return_value = [
["File1 "],
@@ -844,10 +848,10 @@ def test_unload_generated_files(mock_session, credentials, dbapi):
r._unload_generated_files()
assert r._unload_generated_files() == ["File1", "File2"]
- mock_connect.return_value.cursor.return_value.execute.side_effect = Exception()
+ mock_connect.return_value.cursor.return_value.execute.side_effect = DBError()
r = locopy.Redshift(dbapi=dbapi, **credentials)
r.connect()
- with pytest.raises(Exception):
+ with pytest.raises(DBError):
r._unload_generated_files()
@@ -857,19 +861,24 @@ def test_get_column_names(mock_session, credentials, dbapi):
with mock.patch(dbapi.__name__ + ".connect") as mock_connect:
r = locopy.Redshift(dbapi=dbapi, **credentials)
r.connect()
- assert r._get_column_names("query") == None
+ assert r._get_column_names("query") is None
sql = "SELECT * FROM (query) WHERE 1 = 0"
- assert mock_connect.return_value.cursor.return_value.execute.called_with(sql, ())
+ assert mock_connect.return_value.cursor.return_value.execute.called_with(
+ sql, ()
+ )
- mock_connect.return_value.cursor.return_value.description = [["COL1 "], ["COL2 "]]
+ mock_connect.return_value.cursor.return_value.description = [
+ ["COL1 "],
+ ["COL2 "],
+ ]
r = locopy.Redshift(dbapi=dbapi, **credentials)
r.connect()
assert r._get_column_names("query") == ["COL1", "COL2"]
- mock_connect.return_value.cursor.return_value.execute.side_effect = Exception()
+ mock_connect.return_value.cursor.return_value.execute.side_effect = DBError()
r = locopy.Redshift(dbapi=dbapi, **credentials)
r.connect()
- with pytest.raises(Exception):
+ with pytest.raises(DBError):
r._get_column_names("query")
@@ -888,13 +897,13 @@ def testunload(mock_session, credentials, dbapi):
def testunload_no_connection(mock_session, credentials, dbapi):
with mock.patch(dbapi.__name__ + ".connect") as mock_connect:
r = locopy.Redshift(dbapi=dbapi, **credentials)
- with pytest.raises(Exception):
+ with pytest.raises(DBError):
r.unload("query", "path")
- mock_connect.return_value.cursor.return_value.execute.side_effect = Exception()
+ mock_connect.return_value.cursor.return_value.execute.side_effect = DBError()
r = locopy.Redshift(dbapi=dbapi, **credentials)
r.connect()
- with pytest.raises(Exception):
+ with pytest.raises(DBError):
r.unload("query", "path")
@@ -932,7 +941,9 @@ def testinsert_dataframe_to_table(mock_session, credentials, dbapi):
test_df,
"database.schema.test",
create=True,
- metadata=OrderedDict([("col1", "int"), ("col2", "varchar"), ("col3", "date")]),
+ metadata=OrderedDict(
+ [("col1", "int"), ("col2", "varchar"), ("col3", "date")]
+ ),
)
mock_connect.return_value.cursor.return_value.execute.assert_any_call(
@@ -943,11 +954,15 @@ def testinsert_dataframe_to_table(mock_session, credentials, dbapi):
(),
)
- r.insert_dataframe_to_table(test_df, "database.schema.test", create=False, batch_size=1)
+ r.insert_dataframe_to_table(
+ test_df, "database.schema.test", create=False, batch_size=1
+ )
mock_connect.return_value.cursor.return_value.execute.assert_any_call(
- "INSERT INTO database.schema.test (a,b,c) VALUES ('1', 'x', '2011-01-01')", ()
+ "INSERT INTO database.schema.test (a,b,c) VALUES ('1', 'x', '2011-01-01')",
+ (),
)
mock_connect.return_value.cursor.return_value.execute.assert_any_call(
- "INSERT INTO database.schema.test (a,b,c) VALUES ('2', 'y', '2001-04-02')", ()
+ "INSERT INTO database.schema.test (a,b,c) VALUES ('2', 'y', '2001-04-02')",
+ (),
)
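The test_redshift.py hunks repeatedly narrow `pytest.raises(Exception)` to `pytest.raises(DBError)` and switch the mocked side effects to match. A rough sketch of the rationale, using a simplified stand-in for `locopy.errors.DBError` and a hypothetical helper: a blind `Exception` match also passes on unrelated failures such as an `AttributeError` from a typo, while the narrower match only passes when the code raises the error the test is actually about.

```python
import pytest


class DBError(Exception):
    """Simplified stand-in for locopy.errors.DBError (illustration only)."""


def _get_column_names(query):
    # Hypothetical helper: wraps any backend failure in DBError.
    if not query:
        raise DBError("query could not be described")
    return ["col1", "col2"]


def test_get_column_names_failure():
    # Passes only when DBError is raised; a stray TypeError or AttributeError
    # from a typo would now fail the test instead of being silently accepted.
    with pytest.raises(DBError):
        _get_column_names("")
```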
diff --git a/tests/test_s3.py b/tests/test_s3.py
index 1d23b81..53f7f91 100644
--- a/tests/test_s3.py
+++ b/tests/test_s3.py
@@ -25,12 +25,11 @@
from unittest import mock
import hypothesis.strategies as st
+import locopy
import pg8000
import psycopg2
import pytest
from hypothesis import given
-
-import locopy
from locopy.errors import (
S3CredentialsError,
S3DeletionError,
@@ -63,7 +62,7 @@
def test_mock_s3_session_profile_without_kms(profile, mock_session, dbapi):
s = locopy.S3(profile=profile)
mock_session.assert_called_with(profile_name=profile)
- assert s.kms_key == None
+ assert s.kms_key is None
@pytest.mark.parametrize("dbapi", DBAPIS)
@@ -80,7 +79,7 @@ def test_mock_s3_session_profile_with_kms(input_kms_key, profile, mock_session,
def test_mock_s3_session_profile_without_any(mock_session, dbapi):
s = locopy.S3()
mock_session.assert_called_with(profile_name=None)
- assert s.kms_key == None
+ assert s.kms_key is None
@pytest.mark.parametrize("dbapi", DBAPIS)
@@ -122,8 +121,8 @@ def test_get_credentials(mock_cred, aws_creds):
expected = "aws_access_key_id=access;" "aws_secret_access_key=secret"
assert cred_string == expected
- mock_cred.side_effect = Exception("Exception")
- with pytest.raises(Exception):
+ mock_cred.side_effect = S3CredentialsError("Exception")
+ with pytest.raises(S3CredentialsError):
locopy.S3()
@@ -139,7 +138,10 @@ def test_generate_s3_path(mock_session):
def test_generate_unload_path(mock_session):
s = locopy.S3()
assert s._generate_unload_path("TEST", "FOLDER/") == "s3://TEST/FOLDER/"
- assert s._generate_unload_path("TEST SPACE", "FOLDER SPACE/") == "s3://TEST SPACE/FOLDER SPACE/"
+ assert (
+ s._generate_unload_path("TEST SPACE", "FOLDER SPACE/")
+ == "s3://TEST SPACE/FOLDER SPACE/"
+ )
assert s._generate_unload_path("TEST", "PREFIX") == "s3://TEST/PREFIX"
assert s._generate_unload_path("TEST", None) == "s3://TEST"
@@ -267,7 +269,10 @@ def test_download_from_s3(mock_session, mock_config):
s = locopy.S3()
s.download_from_s3(S3_DEFAULT_BUCKET, LOCAL_TEST_FILE, LOCAL_TEST_FILE)
s.s3.download_file.assert_called_with(
- S3_DEFAULT_BUCKET, LOCAL_TEST_FILE, os.path.basename(LOCAL_TEST_FILE), Config=mock_config()
+ S3_DEFAULT_BUCKET,
+ LOCAL_TEST_FILE,
+ os.path.basename(LOCAL_TEST_FILE),
+ Config=mock_config(),
)
mock_config.side_effect = Exception()
@@ -316,7 +321,9 @@ def test_delete_list_from_s3_multiple_with_folder(mock_session, mock_delete):
mock.call("test_bucket", "test_folder/test.2"),
]
s = locopy.S3()
- s.delete_list_from_s3(["test_bucket/test_folder/test.1", "test_bucket/test_folder/test.2"])
+ s.delete_list_from_s3(
+ ["test_bucket/test_folder/test.1", "test_bucket/test_folder/test.2"]
+ )
mock_delete.assert_has_calls(calls)
@@ -331,7 +338,9 @@ def test_delete_list_from_s3_multiple_without_folder(mock_session, mock_delete):
@mock.patch("locopy.s3.S3.delete_from_s3")
@mock.patch("locopy.s3.Session")
-def test_delete_list_from_s3_single_with_folder_and_special_chars(mock_session, mock_delete):
+def test_delete_list_from_s3_single_with_folder_and_special_chars(
+ mock_session, mock_delete
+):
calls = [mock.call("test_bucket", r"test_folder/#$#@$@#$dffksdojfsdf\\\\\/test.1")]
s = locopy.S3()
s.delete_list_from_s3([r"test_bucket/test_folder/#$#@$@#$dffksdojfsdf\\\\\/test.1"])
@@ -344,22 +353,33 @@ def test_delete_list_from_s3_exception(mock_session, mock_delete):
s = locopy.S3()
mock_delete.side_effect = S3UploadError("Upload Exception")
with pytest.raises(S3UploadError):
- s.delete_list_from_s3(["test_bucket/test_folder/test.1", "test_bucket/test_folder/test.2"])
+ s.delete_list_from_s3(
+ ["test_bucket/test_folder/test.1", "test_bucket/test_folder/test.2"]
+ )
@mock.patch("locopy.s3.Session")
def test_parse_s3_url(mock_session):
s = locopy.S3()
- assert s.parse_s3_url("s3://bucket/folder/file.txt") == ("bucket", "folder/file.txt")
+ assert s.parse_s3_url("s3://bucket/folder/file.txt") == (
+ "bucket",
+ "folder/file.txt",
+ )
assert s.parse_s3_url("s3://bucket/folder/") == ("bucket", "folder/")
assert s.parse_s3_url("s3://bucket") == ("bucket", "")
- assert s.parse_s3_url(r"s3://bucket/!@#$%\\\/file.txt") == ("bucket", r"!@#$%\\\/file.txt")
+ assert s.parse_s3_url(r"s3://bucket/!@#$%\\\/file.txt") == (
+ "bucket",
+ r"!@#$%\\\/file.txt",
+ )
assert s.parse_s3_url("s3://") == ("", "")
assert s.parse_s3_url("bucket/folder/file.txt") == ("bucket", "folder/file.txt")
assert s.parse_s3_url("bucket/folder/") == ("bucket", "folder/")
assert s.parse_s3_url("bucket") == ("bucket", "")
- assert s.parse_s3_url(r"bucket/!@#$%\\\/file.txt") == ("bucket", r"!@#$%\\\/file.txt")
+ assert s.parse_s3_url(r"bucket/!@#$%\\\/file.txt") == (
+ "bucket",
+ r"!@#$%\\\/file.txt",
+ )
assert s.parse_s3_url("") == ("", "")
@@ -394,7 +414,10 @@ def test_download_list_from_s3_multiple(mock_session, mock_download):
]
s = locopy.S3()
res = s.download_list_from_s3(["s3://bucket/test.1", "s3://bucket/test.2"])
- assert res == [os.path.join(os.getcwd(), "test.1"), os.path.join(os.getcwd(), "test.2")]
+ assert res == [
+ os.path.join(os.getcwd(), "test.1"),
+ os.path.join(os.getcwd(), "test.2"),
+ ]
mock_download.assert_has_calls(calls)
@@ -407,8 +430,13 @@ def test_download_list_from_s3_multiple_with_localpath(mock_session, mock_downlo
mock.call("bucket", "test.2", os.path.join(tmp_path.name, "test.2")),
]
s = locopy.S3()
- res = s.download_list_from_s3(["s3://bucket/test.1", "s3://bucket/test.2"], tmp_path.name)
- assert res == [os.path.join(tmp_path.name, "test.1"), os.path.join(tmp_path.name, "test.2")]
+ res = s.download_list_from_s3(
+ ["s3://bucket/test.1", "s3://bucket/test.2"], tmp_path.name
+ )
+ assert res == [
+ os.path.join(tmp_path.name, "test.1"),
+ os.path.join(tmp_path.name, "test.2"),
+ ]
mock_download.assert_has_calls(calls)
tmp_path.cleanup()
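Several hunks in test_redshift.py and test_s3.py replace `== None` with `is None` (pycodestyle's E711). A small sketch, with a deliberately contrived class, of why identity is the safer check: `None` is a singleton, and `is` cannot be intercepted by a custom `__eq__`.

```python
class AlwaysEqual:
    """Contrived example: equality that claims to match everything."""

    def __eq__(self, other):
        return True


obj = AlwaysEqual()
assert obj == None  # noqa: E711 - passes only because __eq__ is overridden
assert obj is not None  # identity comparison cannot be fooled
```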
diff --git a/tests/test_snowflake.py b/tests/test_snowflake.py
index a6b4aa2..d24e76b 100644
--- a/tests/test_snowflake.py
+++ b/tests/test_snowflake.py
@@ -27,13 +27,12 @@
from unittest import mock
import hypothesis.strategies as s
+import locopy
import pytest
import snowflake.connector
from hypothesis import HealthCheck, given, settings
-
-import locopy
from locopy import Snowflake
-from locopy.errors import CredentialsError, DBError
+from locopy.errors import DBError
PROFILE = "test"
KMS = "kms_test"
@@ -159,84 +158,92 @@ def test_with_connect(mock_session, sf_credentials):
sf.conn.cursor.return_value.execute.assert_any_call("USE SCHEMA schema", ())
mock_connect.side_effect = Exception("Connect Exception")
- with pytest.raises(DBError):
- with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf:
- sf.cursor
+ with (
+ pytest.raises(DBError),
+ Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf,
+ ):
+ sf.cursor # noqa: B018
@mock.patch("locopy.s3.Session")
def test_upload_to_internal(mock_session, sf_credentials):
- with mock.patch("snowflake.connector.connect") as mock_connect:
- with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf:
- sf.upload_to_internal("/some/file", "@~/internal")
- sf.conn.cursor.return_value.execute.assert_called_with(
- "PUT 'file:///some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=True",
- (),
- )
+ with (
+ mock.patch("snowflake.connector.connect") as mock_connect,
+ Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf,
+ ):
+ sf.upload_to_internal("/some/file", "@~/internal")
+ sf.conn.cursor.return_value.execute.assert_called_with(
+ "PUT 'file:///some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=True",
+ (),
+ )
- sf.upload_to_internal(
- "/some/file", "@~/internal", parallel=99, auto_compress=False
- )
- sf.conn.cursor.return_value.execute.assert_called_with(
- "PUT 'file:///some/file' @~/internal PARALLEL=99 AUTO_COMPRESS=False OVERWRITE=True",
- (),
- )
+ sf.upload_to_internal(
+ "/some/file", "@~/internal", parallel=99, auto_compress=False
+ )
+ sf.conn.cursor.return_value.execute.assert_called_with(
+ "PUT 'file:///some/file' @~/internal PARALLEL=99 AUTO_COMPRESS=False OVERWRITE=True",
+ (),
+ )
- sf.upload_to_internal("/some/file", "@~/internal", overwrite=False)
- sf.conn.cursor.return_value.execute.assert_called_with(
- "PUT 'file:///some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=False",
- (),
- )
+ sf.upload_to_internal("/some/file", "@~/internal", overwrite=False)
+ sf.conn.cursor.return_value.execute.assert_called_with(
+ "PUT 'file:///some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=False",
+ (),
+ )
- # exception
- sf.conn.cursor.return_value.execute.side_effect = Exception("PUT Exception")
- with pytest.raises(DBError):
- sf.upload_to_internal("/some/file", "@~/internal")
+ # exception
+ sf.conn.cursor.return_value.execute.side_effect = Exception("PUT Exception")
+ with pytest.raises(DBError):
+ sf.upload_to_internal("/some/file", "@~/internal")
@mock.patch("locopy.snowflake.PurePath", new=PureWindowsPath)
@mock.patch("locopy.s3.Session")
def test_upload_to_internal_windows(mock_session, sf_credentials):
- with mock.patch("snowflake.connector.connect") as mock_connect:
- with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf:
-
- sf.upload_to_internal(r"C:\some\file", "@~/internal")
- sf.conn.cursor.return_value.execute.assert_called_with(
- "PUT 'file://C:/some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=True",
- (),
- )
+ with (
+ mock.patch("snowflake.connector.connect") as mock_connect,
+ Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf,
+ ):
+ sf.upload_to_internal(r"C:\some\file", "@~/internal")
+ sf.conn.cursor.return_value.execute.assert_called_with(
+ "PUT 'file://C:/some/file' @~/internal PARALLEL=4 AUTO_COMPRESS=True OVERWRITE=True",
+ (),
+ )
@mock.patch("locopy.s3.Session")
def test_download_from_internal(mock_session, sf_credentials):
- with mock.patch("snowflake.connector.connect") as mock_connect:
- with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf:
- sf.download_from_internal("@~/internal", "/some/file")
- sf.conn.cursor.return_value.execute.assert_called_with(
- "GET @~/internal 'file:///some/file' PARALLEL=10", ()
- )
+ with (
+ mock.patch("snowflake.connector.connect") as mock_connect,
+ Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf,
+ ):
+ sf.download_from_internal("@~/internal", "/some/file")
+ sf.conn.cursor.return_value.execute.assert_called_with(
+ "GET @~/internal 'file:///some/file' PARALLEL=10", ()
+ )
- sf.download_from_internal("@~/internal", "/some/file", parallel=99)
- sf.conn.cursor.return_value.execute.assert_called_with(
- "GET @~/internal 'file:///some/file' PARALLEL=99", ()
- )
+ sf.download_from_internal("@~/internal", "/some/file", parallel=99)
+ sf.conn.cursor.return_value.execute.assert_called_with(
+ "GET @~/internal 'file:///some/file' PARALLEL=99", ()
+ )
- # exception
- sf.conn.cursor.return_value.execute.side_effect = Exception("GET Exception")
- with pytest.raises(DBError):
- sf.download_from_internal("@~/internal", "/some/file")
+ # exception
+ sf.conn.cursor.return_value.execute.side_effect = Exception("GET Exception")
+ with pytest.raises(DBError):
+ sf.download_from_internal("@~/internal", "/some/file")
@mock.patch("locopy.snowflake.PurePath", new=PureWindowsPath)
@mock.patch("locopy.s3.Session")
def test_download_from_internal_windows(mock_session, sf_credentials):
- with mock.patch("snowflake.connector.connect") as mock_connect:
- with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf:
-
- sf.download_from_internal("@~/internal", r"C:\some\file")
- sf.conn.cursor.return_value.execute.assert_called_with(
- "GET @~/internal 'file://C:/some/file' PARALLEL=10", ()
- )
+ with (
+ mock.patch("snowflake.connector.connect") as mock_connect,
+ Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf,
+ ):
+ sf.download_from_internal("@~/internal", r"C:\some\file")
+ sf.conn.cursor.return_value.execute.assert_called_with(
+ "GET @~/internal 'file://C:/some/file' PARALLEL=10", ()
+ )
@pytest.mark.parametrize(
@@ -275,42 +282,40 @@ def test_download_from_internal_windows(mock_session, sf_credentials):
def test_copy(
mock_session, file_type, format_options, copy_options, expected, sf_credentials
):
- with mock.patch("snowflake.connector.connect") as mock_connect:
- with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf:
-
- sf.copy(
- "table_name",
- "@~/stage",
- file_type=file_type,
- format_options=format_options,
- copy_options=copy_options,
- )
- sf.conn.cursor.return_value.execute.assert_called_with(
- "COPY INTO table_name FROM '@~/stage' FILE_FORMAT = {0}".format(
- expected
- ),
- (),
- )
+ with (
+ mock.patch("snowflake.connector.connect") as mock_connect,
+ Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf,
+ ):
+ sf.copy(
+ "table_name",
+ "@~/stage",
+ file_type=file_type,
+ format_options=format_options,
+ copy_options=copy_options,
+ )
+ sf.conn.cursor.return_value.execute.assert_called_with(
+ f"COPY INTO table_name FROM '@~/stage' FILE_FORMAT = {expected}",
+ (),
+ )
@mock.patch("locopy.s3.Session")
def test_copy_exception(mock_session, sf_credentials):
- with mock.patch("snowflake.connector.connect") as mock_connect:
- with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf:
-
- with pytest.raises(ValueError):
- sf.copy("table_name", "@~/stage", file_type="unknown")
-
- # exception
- sf.conn.cursor.return_value.execute.side_effect = Exception(
- "COPY Exception"
- )
- with pytest.raises(DBError):
- sf.copy("table_name", "@~/stage")
+ with (
+ mock.patch("snowflake.connector.connect") as mock_connect,
+ Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf,
+ ):
+ with pytest.raises(ValueError):
+ sf.copy("table_name", "@~/stage", file_type="unknown")
+
+ # exception
+ sf.conn.cursor.return_value.execute.side_effect = Exception("COPY Exception")
+ with pytest.raises(DBError):
+ sf.copy("table_name", "@~/stage")
- sf.conn = None
- with pytest.raises(DBError):
- sf.copy("table_name", "@~/stage")
+ sf.conn = None
+ with pytest.raises(DBError):
+ sf.copy("table_name", "@~/stage")
@pytest.mark.parametrize(
@@ -359,41 +364,41 @@ def test_unload(
expected,
sf_credentials,
):
- with mock.patch("snowflake.connector.connect") as mock_connect:
- with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf:
-
- sf.unload(
- "@~/stage",
- "table_name",
- file_type=file_type,
- format_options=format_options,
- header=header,
- copy_options=copy_options,
- )
- sf.conn.cursor.return_value.execute.assert_called_with(
- "COPY INTO @~/stage FROM table_name FILE_FORMAT = {0}".format(expected),
- (),
- )
+ with (
+ mock.patch("snowflake.connector.connect") as mock_connect,
+ Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf,
+ ):
+ sf.unload(
+ "@~/stage",
+ "table_name",
+ file_type=file_type,
+ format_options=format_options,
+ header=header,
+ copy_options=copy_options,
+ )
+ sf.conn.cursor.return_value.execute.assert_called_with(
+ f"COPY INTO @~/stage FROM table_name FILE_FORMAT = {expected}",
+ (),
+ )
@mock.patch("locopy.s3.Session")
def test_unload_exception(mock_session, sf_credentials):
- with mock.patch("snowflake.connector.connect") as mock_connect:
- with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf:
-
- with pytest.raises(ValueError):
- sf.unload("table_name", "@~/stage", file_type="unknown")
-
- # exception
- sf.conn.cursor.return_value.execute.side_effect = Exception(
- "UNLOAD Exception"
- )
- with pytest.raises(DBError):
- sf.unload("@~/stage", "table_name")
+ with (
+ mock.patch("snowflake.connector.connect") as mock_connect,
+ Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf,
+ ):
+ with pytest.raises(ValueError):
+ sf.unload("table_name", "@~/stage", file_type="unknown")
+
+ # exception
+ sf.conn.cursor.return_value.execute.side_effect = Exception("UNLOAD Exception")
+ with pytest.raises(DBError):
+ sf.unload("@~/stage", "table_name")
- sf.conn = None
- with pytest.raises(DBError):
- sf.unload("@~/stage", "table_name")
+ sf.conn = None
+ with pytest.raises(DBError):
+ sf.unload("@~/stage", "table_name")
@mock.patch("locopy.s3.Session")
@@ -401,18 +406,20 @@ def test_to_pandas(mock_session, sf_credentials):
import pandas as pd
test_df = pd.read_csv(os.path.join(CURR_DIR, "data", "mock_dataframe.txt"), sep=",")
- with mock.patch("snowflake.connector.connect") as mock_connect:
- with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf:
- sf.cursor._query_result_format = "arrow"
- sf.to_dataframe()
- sf.conn.cursor.return_value.fetch_pandas_all.assert_called_with()
+ with (
+ mock.patch("snowflake.connector.connect") as mock_connect,
+ Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf,
+ ):
+ sf.cursor._query_result_format = "arrow"
+ sf.to_dataframe()
+ sf.conn.cursor.return_value.fetch_pandas_all.assert_called_with()
- sf.cursor._query_result_format = "json"
- sf.to_dataframe()
- sf.conn.cursor.return_value.fetchall.assert_called_with()
+ sf.cursor._query_result_format = "json"
+ sf.to_dataframe()
+ sf.conn.cursor.return_value.fetchall.assert_called_with()
- sf.to_dataframe(5)
- sf.conn.cursor.return_value.fetchmany.assert_called_with(5)
+ sf.to_dataframe(5)
+ sf.conn.cursor.return_value.fetchmany.assert_called_with(5)
@mock.patch("locopy.s3.Session")
@@ -420,61 +427,63 @@ def test_insert_dataframe_to_table(mock_session, sf_credentials):
import pandas as pd
test_df = pd.read_csv(os.path.join(CURR_DIR, "data", "mock_dataframe.txt"), sep=",")
- with mock.patch("snowflake.connector.connect") as mock_connect:
- with Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf:
- sf.insert_dataframe_to_table(test_df, "database.schema.test")
- sf.conn.cursor.return_value.executemany.assert_called_with(
- "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)",
- [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")],
- )
+ with (
+ mock.patch("snowflake.connector.connect") as mock_connect,
+ Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf,
+ ):
+ sf.insert_dataframe_to_table(test_df, "database.schema.test")
+ sf.conn.cursor.return_value.executemany.assert_called_with(
+ "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)",
+ [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")],
+ )
- sf.insert_dataframe_to_table(test_df, "database.schema.test", create=True)
- sf.conn.cursor.return_value.execute.assert_any_call(
- "CREATE TABLE database.schema.test (a int,b varchar,c date)", ()
- )
- sf.conn.cursor.return_value.executemany.assert_called_with(
- "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)",
- [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")],
- )
+ sf.insert_dataframe_to_table(test_df, "database.schema.test", create=True)
+ sf.conn.cursor.return_value.execute.assert_any_call(
+ "CREATE TABLE database.schema.test (a int,b varchar,c date)", ()
+ )
+ sf.conn.cursor.return_value.executemany.assert_called_with(
+ "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)",
+ [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")],
+ )
- sf.insert_dataframe_to_table(
- test_df, "database.schema.test", columns=["a", "b"]
- )
+ sf.insert_dataframe_to_table(
+ test_df, "database.schema.test", columns=["a", "b"]
+ )
- sf.conn.cursor.return_value.executemany.assert_called_with(
- "INSERT INTO database.schema.test (a,b) VALUES (%s,%s)",
- [("1", "x"), ("2", "y")],
- )
+ sf.conn.cursor.return_value.executemany.assert_called_with(
+ "INSERT INTO database.schema.test (a,b) VALUES (%s,%s)",
+ [("1", "x"), ("2", "y")],
+ )
- sf.insert_dataframe_to_table(
- test_df,
- "database.schema.test",
- create=True,
- metadata=OrderedDict(
- [("col1", "int"), ("col2", "varchar"), ("col3", "date")]
- ),
- )
+ sf.insert_dataframe_to_table(
+ test_df,
+ "database.schema.test",
+ create=True,
+ metadata=OrderedDict(
+ [("col1", "int"), ("col2", "varchar"), ("col3", "date")]
+ ),
+ )
- sf.conn.cursor.return_value.execute.assert_any_call(
- "CREATE TABLE database.schema.test (col1 int,col2 varchar,col3 date)",
- (),
- )
- sf.conn.cursor.return_value.executemany.assert_called_with(
- "INSERT INTO database.schema.test (col1,col2,col3) VALUES (%s,%s,%s)",
- [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")],
- )
+ sf.conn.cursor.return_value.execute.assert_any_call(
+ "CREATE TABLE database.schema.test (col1 int,col2 varchar,col3 date)",
+ (),
+ )
+ sf.conn.cursor.return_value.executemany.assert_called_with(
+ "INSERT INTO database.schema.test (col1,col2,col3) VALUES (%s,%s,%s)",
+ [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")],
+ )
- sf.insert_dataframe_to_table(
- test_df,
- "database.schema.test",
- create=False,
- metadata=OrderedDict(
- [("col1", "int"), ("col2", "varchar"), ("col3", "date")]
- ),
- )
+ sf.insert_dataframe_to_table(
+ test_df,
+ "database.schema.test",
+ create=False,
+ metadata=OrderedDict(
+ [("col1", "int"), ("col2", "varchar"), ("col3", "date")]
+ ),
+ )
- # mock_session.warn.assert_called_with('Metadata will not be used because create is set to False.')
- sf.conn.cursor.return_value.executemany.assert_called_with(
- "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)",
- [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")],
- )
+ # mock_session.warn.assert_called_with('Metadata will not be used because create is set to False.')
+ sf.conn.cursor.return_value.executemany.assert_called_with(
+ "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)",
+ [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")],
+ )
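The test_snowflake.py hunks collapse nested `with mock.patch(...)`/`with Snowflake(...)` blocks into a single parenthesized multi-item `with`, the pattern ruff's SIM117 suggests. A minimal sketch with hypothetical context managers follows; note the parenthesized form is documented as official syntax from Python 3.10, though CPython 3.9's parser already accepts it.

```python
from contextlib import contextmanager


@contextmanager
def resource(name):
    # Hypothetical context manager, used only to show enter/exit ordering.
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")


# Nested form (what these tests used before):
with resource("mock_connect"):
    with resource("snowflake_connection"):
        pass

# Single parenthesized form (what the hunks switch to):
with (
    resource("mock_connect") as patched,
    resource("snowflake_connection") as sf,
):
    print(patched, sf)
```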
diff --git a/tests/test_utility.py b/tests/test_utility.py
index bd32612..ec6f8f2 100644
--- a/tests/test_utility.py
+++ b/tests/test_utility.py
@@ -20,23 +20,31 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
+import datetime
import os
import sys
from io import StringIO
from itertools import cycle
from pathlib import Path
from unittest import mock
-import datetime
-
-import pytest
import locopy.utility as util
-from locopy.errors import (CompressionError, CredentialsError,
- LocopyConcatError, LocopyIgnoreHeaderError,
- LocopySplitError)
-from locopy.utility import (compress_file, compress_file_list,
- concatenate_files, find_column_type,
- get_ignoreheader_number, split_file)
+import pytest
+from locopy.errors import (
+ CompressionError,
+ CredentialsError,
+ LocopyConcatError,
+ LocopyIgnoreHeaderError,
+ LocopySplitError,
+)
+from locopy.utility import (
+ compress_file,
+ compress_file_list,
+ concatenate_files,
+ find_column_type,
+ get_ignoreheader_number,
+ split_file,
+)
GOOD_CONFIG_YAML = """host: my.redshift.cluster.com
port: 1234
@@ -56,7 +64,7 @@ def cleanup(splits):
def compare_file_contents(base_file, check_files):
- check_files = cycle([open(x, "rb") for x in check_files])
+ check_files = cycle([open(x, "rb") for x in check_files]) # noqa: SIM115
with open(base_file, "rb") as base:
for line in base:
cfile = next(check_files)
@@ -211,7 +219,7 @@ def test_split_file_exception():
with pytest.raises(LocopySplitError):
split_file(input_file, output_file, "Test")
- with mock.patch("{0}.next".format(builtin_module_name)) as mock_next:
+ with mock.patch(f"{builtin_module_name}.next") as mock_next:
mock_next.side_effect = Exception("SomeException")
with pytest.raises(LocopySplitError):
@@ -229,7 +237,7 @@ def test_split_file_exception():
@mock.patch("locopy.utility.open", mock.mock_open(read_data=GOOD_CONFIG_YAML))
def test_read_config_yaml_good():
actual = util.read_config_yaml("filename.yml")
- assert set(actual.keys()) == set(["host", "port", "database", "user", "password"])
+ assert set(actual.keys()) == {"host", "port", "database", "user", "password"}
assert actual["host"] == "my.redshift.cluster.com"
assert actual["port"] == 1234
assert actual["database"] == "db"
@@ -239,7 +247,7 @@ def test_read_config_yaml_good():
def test_read_config_yaml_io():
actual = util.read_config_yaml(StringIO(GOOD_CONFIG_YAML))
- assert set(actual.keys()) == set(["host", "port", "database", "user", "password"])
+ assert set(actual.keys()) == {"host", "port", "database", "user", "password"}
assert actual["host"] == "my.redshift.cluster.com"
assert actual["port"] == 1234
assert actual["database"] == "db"
@@ -258,7 +266,8 @@ def test_concatenate_files():
with mock.patch("locopy.utility.os.remove") as mock_remove:
concatenate_files(inputs, output)
assert mock_remove.call_count == 3
- assert [int(line.rstrip("\n")) for line in open(output)] == list(range(1, 16))
+ with open(output) as f:
+ assert [int(line.rstrip("\n")) for line in f] == list(range(1, 16))
os.remove(output)
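The `with open(output) as f:` change above, and the `# noqa: SIM115` earlier in this file, both concern deterministic file closing. A short sketch with a hypothetical scratch file: a bare `open()` inside an expression is only closed when the handle is garbage-collected, whereas the `with` form closes it even if the assertion fails.

```python
import os
import tempfile

# Hypothetical scratch file, used only for this sketch.
path = os.path.join(tempfile.gettempdir(), "locopy_concat_sketch.txt")
with open(path, "w") as f:
    f.write("1\n2\n3\n")

# Bare open() in an expression: closed only when the handle is garbage-collected.
lines = [int(line.rstrip("\n")) for line in open(path)]  # noqa: SIM115

# with-block: the handle is closed deterministically, even if the assert raises.
with open(path) as f:
    assert [int(line.rstrip("\n")) for line in f] == lines

os.remove(path)
```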
@@ -275,7 +284,6 @@ def test_concatenate_files_exception():
def test_find_column_type():
-
from decimal import Decimal
import pandas as pd
@@ -341,29 +349,27 @@ def test_find_column_type():
assert find_column_type(input_text, "snowflake") == output_text_snowflake
assert find_column_type(input_text, "redshift") == output_text_redshift
-def test_find_column_type_new():
-
- from decimal import Decimal
+def test_find_column_type_new():
import pandas as pd
input_text = pd.DataFrame.from_dict(
- {
- "a": [1],
- "b": [pd.Timestamp('2017-01-01T12+0')],
- "c": [1.2],
- "d": ["a"],
- "e": [True]
- }
-)
+ {
+ "a": [1],
+ "b": [pd.Timestamp("2017-01-01T12+0")],
+ "c": [1.2],
+ "d": ["a"],
+ "e": [True],
+ }
+ )
input_text = input_text.astype(
dtype={
- "a": pd.Int64Dtype(),
- "b": pd.DatetimeTZDtype(tz=datetime.timezone.utc),
- "c": pd.Float64Dtype(),
- "d": pd.StringDtype(),
- "e": pd.BooleanDtype()
+ "a": pd.Int64Dtype(),
+ "b": pd.DatetimeTZDtype(tz=datetime.timezone.utc),
+ "c": pd.Float64Dtype(),
+ "d": pd.StringDtype(),
+ "e": pd.BooleanDtype(),
}
)
@@ -375,7 +381,7 @@ def test_find_column_type_new():
"e": "boolean",
}
- output_text_redshift = {
+ output_text_redshift = {
"a": "int",
"b": "timestamp",
"c": "float",
@@ -387,7 +393,6 @@ def test_find_column_type_new():
assert find_column_type(input_text, "redshift") == output_text_redshift
-
def test_get_ignoreheader_number():
assert (
get_ignoreheader_number(