diff --git a/.github/workflows/edgetest.yml b/.github/workflows/edgetest.yml
index b57d7e9..f699c1f 100644
--- a/.github/workflows/edgetest.yml
+++ b/.github/workflows/edgetest.yml
@@ -19,8 +19,8 @@ jobs:
cp tests/data/.locopyrc ~/.locopyrc
cp tests/data/.locopy-sfrc ~/.locopy-sfrc
- id: run-edgetest
- uses: fdosani/run-edgetest-action@v1.0
+ uses: fdosani/run-edgetest-action@v1.3
with:
edgetest-flags: '-c pyproject.toml --export'
base-branch: 'develop'
- skip-pr: 'false'
\ No newline at end of file
+ skip-pr: 'false'
diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml
index c8c5497..51fbb67 100644
--- a/.github/workflows/publish-package.yml
+++ b/.github/workflows/publish-package.yml
@@ -18,7 +18,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
- python-version: '3.8'
+ python-version: '3.10'
- name: Install dependencies
run: python -m pip install -r requirements.txt .[dev]
- name: Build and publish
diff --git a/CODEOWNERS b/CODEOWNERS
index a109fdd..53ed855 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1 +1 @@
-* @fdosani @NikhilJArora @ak-gupta
+* @fdosani @ak-gupta @jdawang @gladysteh99 @NikhilJArora
diff --git a/locopy/_version.py b/locopy/_version.py
index 3a2d64c..350b03d 100644
--- a/locopy/_version.py
+++ b/locopy/_version.py
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "0.5.3"
+__version__ = "0.5.4"
diff --git a/locopy/redshift.py b/locopy/redshift.py
index 3a3cb4e..eeabead 100644
--- a/locopy/redshift.py
+++ b/locopy/redshift.py
@@ -25,14 +25,8 @@
from .errors import DBError, S3CredentialsError
from .logger import INFO, get_logger
from .s3 import S3
-from .utility import (
- compress_file_list,
- concatenate_files,
- find_column_type,
- get_ignoreheader_number,
- split_file,
- write_file,
-)
+from .utility import (compress_file_list, concatenate_files, find_column_type,
+ get_ignoreheader_number, split_file, write_file)
logger = get_logger(__name__, INFO)
@@ -153,11 +147,15 @@ class Redshift(S3, Database):
Issue initializing S3 session
"""
- def __init__(self, profile=None, kms_key=None, dbapi=None, config_yaml=None, **kwargs):
+ def __init__(
+ self, profile=None, kms_key=None, dbapi=None, config_yaml=None, **kwargs
+ ):
try:
S3.__init__(self, profile, kms_key)
except S3CredentialsError:
- logger.warning("S3 credentials were not found. S3 functionality is disabled")
+ logger.warning(
+ "S3 credentials were not found. S3 functionality is disabled"
+ )
Database.__init__(self, dbapi, config_yaml, **kwargs)
def connect(self):
@@ -304,7 +302,9 @@ def load_and_copy(
if splits > 1 and ignore_header > 0:
# remove the IGNOREHEADER from copy_options
logger.info("Removing the IGNOREHEADER option as split is enabled")
- copy_options = [i for i in copy_options if not i.startswith("IGNOREHEADER ")]
+ copy_options = [
+ i for i in copy_options if not i.startswith("IGNOREHEADER ")
+ ]
if compress:
copy_options.append("GZIP")
@@ -448,11 +448,16 @@ def unload(self, query, s3path, unload_options=None):
unload_options = unload_options or []
unload_options_text = " ".join(unload_options)
- base_unload_string = "UNLOAD ('{0}')\n" "TO '{1}'\n" "CREDENTIALS '{2}'\n" "{3};"
+ base_unload_string = (
+ "UNLOAD ('{0}')\n" "TO '{1}'\n" "CREDENTIALS '{2}'\n" "{3};"
+ )
try:
sql = base_unload_string.format(
- query.replace("'", r"\'"), s3path, self._credentials_string(), unload_options_text
+ query.replace("'", r"\'"),
+ s3path,
+ self._credentials_string(),
+ unload_options_text,
)
self.execute(sql, commit=True)
except Exception as e:
@@ -494,7 +499,10 @@ def _unload_generated_files(self):
list
List of S3 file names
"""
- sql = "SELECT path FROM stl_unload_log " "WHERE query = pg_last_query_id() ORDER BY path"
+ sql = (
+ "SELECT path FROM stl_unload_log "
+ "WHERE query = pg_last_query_id() ORDER BY path"
+ )
try:
logger.info("Getting list of unloaded files")
self.execute(sql)
@@ -563,7 +571,7 @@ def insert_dataframe_to_table(
if create:
if not metadata:
logger.info("Metadata is missing. Generating metadata ...")
- metadata = find_column_type(dataframe)
+ metadata = find_column_type(dataframe, "redshift")
logger.info("Metadata is complete. Creating new table ...")
create_join = (
@@ -592,7 +600,9 @@ def insert_dataframe_to_table(
"("
+ ", ".join(
[
- "NULL" if pd.isnull(val) else "'" + str(val).replace("'", "''") + "'"
+ "NULL"
+ if pd.isnull(val)
+ else "'" + str(val).replace("'", "''") + "'"
for val in row
]
)
@@ -600,8 +610,10 @@ def insert_dataframe_to_table(
)
to_insert.append(none_row)
string_join = ", ".join(to_insert)
- insert_query = """INSERT INTO {table_name} {columns} VALUES {values}""".format(
- table_name=table_name, columns=column_sql, values=string_join
+ insert_query = (
+ """INSERT INTO {table_name} {columns} VALUES {values}""".format(
+ table_name=table_name, columns=column_sql, values=string_join
+ )
)
self.execute(insert_query, verbose=verbose)
logger.info("Table insertion has completed")
diff --git a/locopy/snowflake.py b/locopy/snowflake.py
index 6d1ea0d..60d3c5e 100644
--- a/locopy/snowflake.py
+++ b/locopy/snowflake.py
@@ -85,7 +85,7 @@
def combine_options(options=None):
- """ Returns the ``copy_options`` or ``format_options`` attribute with spaces in between and as
+ """Returns the ``copy_options`` or ``format_options`` attribute with spaces in between and as
a string. If options is ``None`` then return an empty string.
Parameters
@@ -170,11 +170,15 @@ class Snowflake(S3, Database):
Issue initializing S3 session
"""
- def __init__(self, profile=None, kms_key=None, dbapi=None, config_yaml=None, **kwargs):
+ def __init__(
+ self, profile=None, kms_key=None, dbapi=None, config_yaml=None, **kwargs
+ ):
try:
S3.__init__(self, profile, kms_key)
except S3CredentialsError:
- logger.warning("S3 credentials were not found. S3 functionality is disabled")
+ logger.warning(
+ "S3 credentials were not found. S3 functionality is disabled"
+ )
logger.warning("Only internal stages are available")
Database.__init__(self, dbapi, config_yaml, **kwargs)
@@ -196,7 +200,9 @@ def connect(self):
if self.connection.get("schema") is not None:
self.execute("USE SCHEMA {0}".format(self.connection["schema"]))
- def upload_to_internal(self, local, stage, parallel=4, auto_compress=True, overwrite=True):
+ def upload_to_internal(
+ self, local, stage, parallel=4, auto_compress=True, overwrite=True
+ ):
"""
        Upload file(s) to an internal stage via the ``PUT`` command.
@@ -249,9 +255,13 @@ def download_from_internal(self, stage, local=None, parallel=10):
if local is None:
local = os.getcwd()
local_uri = PurePath(local).as_posix()
- self.execute("GET {0} 'file://{1}' PARALLEL={2}".format(stage, local_uri, parallel))
+ self.execute(
+ "GET {0} 'file://{1}' PARALLEL={2}".format(stage, local_uri, parallel)
+ )
- def copy(self, table_name, stage, file_type="csv", format_options=None, copy_options=None):
+ def copy(
+ self, table_name, stage, file_type="csv", format_options=None, copy_options=None
+ ):
"""Executes the ``COPY INTO
`` command to load CSV files from a stage into
a Snowflake table. If ``file_type == csv`` and ``format_options == None``, ``format_options``
will default to: ``["FIELD_DELIMITER='|'", "SKIP_HEADER=0"]``
@@ -286,7 +296,9 @@ def copy(self, table_name, stage, file_type="csv", format_options=None, copy_opt
if file_type not in COPY_FORMAT_OPTIONS:
raise ValueError(
- "Invalid file_type. Must be one of {0}".format(list(COPY_FORMAT_OPTIONS.keys()))
+ "Invalid file_type. Must be one of {0}".format(
+ list(COPY_FORMAT_OPTIONS.keys())
+ )
)
if format_options is None and file_type == "csv":
@@ -294,7 +306,9 @@ def copy(self, table_name, stage, file_type="csv", format_options=None, copy_opt
format_options_text = combine_options(format_options)
copy_options_text = combine_options(copy_options)
- base_copy_string = "COPY INTO {0} FROM '{1}' " "FILE_FORMAT = (TYPE='{2}' {3}) {4}"
+ base_copy_string = (
+ "COPY INTO {0} FROM '{1}' " "FILE_FORMAT = (TYPE='{2}' {3}) {4}"
+ )
try:
sql = base_copy_string.format(
table_name, stage, file_type, format_options_text, copy_options_text
@@ -350,7 +364,9 @@ def unload(
if file_type not in COPY_FORMAT_OPTIONS:
raise ValueError(
- "Invalid file_type. Must be one of {0}".format(list(UNLOAD_FORMAT_OPTIONS.keys()))
+ "Invalid file_type. Must be one of {0}".format(
+ list(UNLOAD_FORMAT_OPTIONS.keys())
+ )
)
if format_options is None and file_type == "csv":
@@ -364,7 +380,12 @@ def unload(
try:
sql = base_unload_string.format(
- stage, table_name, file_type, format_options_text, header, copy_options_text
+ stage,
+ table_name,
+ file_type,
+ format_options_text,
+ header,
+ copy_options_text,
)
self.execute(sql, commit=True)
except Exception as e:
@@ -422,7 +443,7 @@ def insert_dataframe_to_table(
if create:
if not metadata:
logger.info("Metadata is missing. Generating metadata ...")
- metadata = find_column_type(dataframe)
+ metadata = find_column_type(dataframe, "snowflake")
logger.info("Metadata is complete. Creating new table ...")
create_join = (
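# Illustrative sketch (not part of the diff): combine_options joins a list of
# option strings with spaces and returns an empty string for None, which is how
# copy() assembles its FILE_FORMAT clause. Assumes locopy is installed; the
# options shown are the documented csv defaults.
from locopy.snowflake import combine_options

print(combine_options(["FIELD_DELIMITER='|'", "SKIP_HEADER=0"]))  # FIELD_DELIMITER='|' SKIP_HEADER=0
print(combine_options(None))                                      # "" (empty string)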
diff --git a/locopy/utility.py b/locopy/utility.py
index 7515c1d..d19bf60 100644
--- a/locopy/utility.py
+++ b/locopy/utility.py
@@ -27,13 +27,8 @@
import yaml
-from .errors import (
- CompressionError,
- CredentialsError,
- LocopyConcatError,
- LocopyIgnoreHeaderError,
- LocopySplitError,
-)
+from .errors import (CompressionError, CredentialsError, LocopyConcatError,
+ LocopyIgnoreHeaderError, LocopySplitError)
from .logger import INFO, get_logger
logger = get_logger(__name__, INFO)
@@ -250,7 +245,7 @@ def read_config_yaml(config_yaml):
# make it more granular, eg. include length
-def find_column_type(dataframe):
+def find_column_type(dataframe, warehouse_type: str):
"""
Find data type of each column from the dataframe.
@@ -271,6 +266,9 @@ def find_column_type(dataframe):
----------
dataframe : Pandas dataframe
+    warehouse_type : str
+        Required to properly determine the data types of the uploaded data; must be
+        either "snowflake" or "redshift".
+
Returns
-------
dict
@@ -284,10 +282,16 @@ def find_column_type(dataframe):
def validate_date_object(column):
try:
pd.to_datetime(column)
- if re.search(r"\d+:\d+:\d+", column.sample(1).to_string(index=False)):
+ sample_data = column.sample(1).to_string(index=False)
+ if re.search(r"\d+:\d+:\d+", sample_data):
return "timestamp"
- else:
+ elif warehouse_type == "redshift" or re.search(
+ r"(\d{4}-\d{2}-\d{2})|(\d{2}-[A-Z]{3}-\d{4})|(\d{2}/\d{2}/\d{4})",
+ sample_data,
+ ):
return "date"
+ else:
+ return "varchar"
except (ValueError, TypeError):
return None
@@ -298,6 +302,11 @@ def validate_float_object(column):
except (ValueError, TypeError):
return None
+ if warehouse_type.lower() not in ["snowflake", "redshift"]:
+ raise ValueError(
+ 'warehouse_type argument must be either "snowflake" or "redshift"'
+ )
+
column_type = []
for column in dataframe.columns:
logger.debug("Checking column: %s", column)
diff --git a/pyproject.toml b/pyproject.toml
index ab0b470..9eb3062 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,12 +6,7 @@ authors = [
{ name="Faisal Dosani", email="faisal.dosani@capitalone.com" },
]
license = {text = "Apache Software License"}
-dependencies = [
- "boto3<=1.28.3,>=1.9.92",
- "PyYAML<=6.0,>=5.1",
- "pandas<=2.0.3,>=0.25.2",
- "numpy<=1.25.1,>=1.22.0",
-]
+dependencies = ["boto3<=1.28.39,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.1.0,>=0.25.2", "numpy<=1.25.2,>=1.22.0"]
requires-python = ">=3.8.0"
classifiers = [
@@ -45,24 +40,11 @@ psycopg2 = ["psycopg2-binary>=2.7.7"]
pg8000 = ["pg8000>=1.13.1"]
snowflake = ["snowflake-connector-python[pandas]>=2.1.2"]
docs = ["sphinx", "sphinx_rtd_theme"]
-tests = [
- "hypothesis",
- "pytest",
- "pytest-cov",
-]
-qa = [
- "pre-commit",
- "black",
- "isort",
-]
+tests = ["hypothesis", "pytest", "pytest-cov"]
+qa = ["pre-commit", "black", "isort"]
build = ["build", "twine", "wheel"]
edgetest = ["edgetest", "edgetest-conda"]
-dev = [
- "locopy[tests]",
- "locopy[docs]",
- "locopy[qa]",
- "locopy[build]",
-]
+dev = ["locopy[tests]", "locopy[docs]", "locopy[qa]", "locopy[build]"]
[isort]
multi_line_output = 3
diff --git a/tests/test_utility.py b/tests/test_utility.py
index f2f0475..e4139ac 100644
--- a/tests/test_utility.py
+++ b/tests/test_utility.py
@@ -24,29 +24,20 @@
import sys
from io import StringIO
from itertools import cycle
-from unittest import mock
from pathlib import Path
+from unittest import mock
import pytest
import locopy.utility as util
-from locopy.errors import (
- CompressionError,
- CredentialsError,
- LocopyConcatError,
- LocopyIgnoreHeaderError,
- LocopySplitError,
-)
-from locopy.utility import (
- compress_file,
- compress_file_list,
- concatenate_files,
- find_column_type,
- get_ignoreheader_number,
- split_file,
-)
-
-GOOD_CONFIG_YAML = u"""host: my.redshift.cluster.com
+from locopy.errors import (CompressionError, CredentialsError,
+ LocopyConcatError, LocopyIgnoreHeaderError,
+ LocopySplitError)
+from locopy.utility import (compress_file, compress_file_list,
+ concatenate_files, find_column_type,
+ get_ignoreheader_number, split_file)
+
+GOOD_CONFIG_YAML = """host: my.redshift.cluster.com
port: 1234
database: db
user: userid
@@ -81,7 +72,9 @@ def test_compress_file(mock_shutil, mock_gzip_open, mock_open):
compress_file("input", "output")
mock_open.assert_called_with("input", "rb")
mock_gzip_open.assert_called_with("output", "wb")
- mock_shutil.assert_called_with(mock_open().__enter__(), mock_gzip_open().__enter__())
+ mock_shutil.assert_called_with(
+ mock_open().__enter__(), mock_gzip_open().__enter__()
+ )
@mock.patch("locopy.utility.open")
@@ -110,7 +103,9 @@ def test_compress_file_list(mock_shutil, mock_gzip_open, mock_open, mock_remove)
@mock.patch("locopy.utility.open")
@mock.patch("locopy.utility.gzip.open")
@mock.patch("locopy.utility.shutil.copyfileobj")
-def test_compress_file_list_exception(mock_shutil, mock_gzip_open, mock_open, mock_remove):
+def test_compress_file_list_exception(
+ mock_shutil, mock_gzip_open, mock_open, mock_remove
+):
mock_shutil.side_effect = Exception("SomeException")
with pytest.raises(CompressionError):
compress_file_list(["input1", "input2"])
@@ -123,7 +118,10 @@ def test_split_file():
splits = split_file(input_file, output_file)
assert splits == [input_file]
- expected = ["tests/data/mock_output_file.txt.0", "tests/data/mock_output_file.txt.1"]
+ expected = [
+ "tests/data/mock_output_file.txt.0",
+ "tests/data/mock_output_file.txt.1",
+ ]
splits = split_file(input_file, output_file, 2)
assert splits == expected
assert compare_file_contents(input_file, expected)
@@ -213,18 +211,18 @@ def test_split_file_exception():
split_file(input_file, output_file, "Test")
with mock.patch("{0}.next".format(builtin_module_name)) as mock_next:
- mock_next.side_effect = Exception("SomeException")
+ mock_next.side_effect = Exception("SomeException")
- with pytest.raises(LocopySplitError):
- split_file(input_file, output_file, 2)
- assert not Path("tests/data/mock_output_file.txt.0").exists()
- assert not Path("tests/data/mock_output_file.txt.1").exists()
+ with pytest.raises(LocopySplitError):
+ split_file(input_file, output_file, 2)
+ assert not Path("tests/data/mock_output_file.txt.0").exists()
+ assert not Path("tests/data/mock_output_file.txt.1").exists()
- with pytest.raises(LocopySplitError):
- split_file(input_file, output_file, 3)
- assert not Path("tests/data/mock_output_file.txt.0").exists()
- assert not Path("tests/data/mock_output_file.txt.1").exists()
- assert not Path("tests/data/mock_output_file.txt.2").exists()
+ with pytest.raises(LocopySplitError):
+ split_file(input_file, output_file, 3)
+ assert not Path("tests/data/mock_output_file.txt.0").exists()
+ assert not Path("tests/data/mock_output_file.txt.1").exists()
+ assert not Path("tests/data/mock_output_file.txt.2").exists()
@mock.patch("locopy.utility.open", mock.mock_open(read_data=GOOD_CONFIG_YAML))
@@ -277,9 +275,10 @@ def test_concatenate_files_exception():
def test_find_column_type():
- import pandas as pd
from decimal import Decimal
+ import pandas as pd
+
# add timestamp
input_text = pd.DataFrame.from_dict(
{
@@ -299,9 +298,29 @@ def test_find_column_type():
"j": [None, "2011-01-01 12:11:02", "2022-03-02 23:59:59"],
"k": [Decimal(3.3), Decimal(100), None],
"l": pd.Series([1, 2, 3], dtype="category"),
+ "m": ["2022-02", "2022-03", "2020-02"],
+ "n": ["2020q1", "2021q2", "2022q3"],
+ "o": ["10-DEC-2022", "11-NOV-2020", "10-OCT-2020"],
}
)
- output_text = {
+ output_text_snowflake = {
+ "a": "int",
+ "b": "varchar",
+ "c": "varchar",
+ "d": "date",
+ "e": "float",
+ "f": "float",
+ "g": "varchar",
+ "h": "timestamp",
+ "i": "date",
+ "j": "timestamp",
+ "k": "float",
+ "l": "varchar",
+ "m": "varchar",
+ "n": "varchar",
+ "o": "date",
+ }
+ output_text_redshift = {
"a": "int",
"b": "varchar",
"c": "varchar",
@@ -314,26 +333,45 @@ def test_find_column_type():
"j": "timestamp",
"k": "float",
"l": "varchar",
+ "m": "date",
+ "n": "date",
+ "o": "date",
}
- assert find_column_type(input_text) == output_text
+ assert find_column_type(input_text, "snowflake") == output_text_snowflake
+ assert find_column_type(input_text, "redshift") == output_text_redshift
def test_get_ignoreheader_number():
assert (
get_ignoreheader_number(
- ["DATEFORMAT 'auto'", "COMPUPDATE ON", "TRUNCATECOLUMNS", "IGNOREHEADER as 1"]
+ [
+ "DATEFORMAT 'auto'",
+ "COMPUPDATE ON",
+ "TRUNCATECOLUMNS",
+ "IGNOREHEADER as 1",
+ ]
)
== 1
)
assert (
get_ignoreheader_number(
- ["DATEFORMAT 'auto'", "COMPUPDATE ON", "TRUNCATECOLUMNS", "IGNOREHEADER as 2"]
+ [
+ "DATEFORMAT 'auto'",
+ "COMPUPDATE ON",
+ "TRUNCATECOLUMNS",
+ "IGNOREHEADER as 2",
+ ]
)
== 2
)
assert (
get_ignoreheader_number(
- ["DATEFORMAT 'auto'", "COMPUPDATE ON", "TRUNCATECOLUMNS", "IGNOREHEADER as 99"]
+ [
+ "DATEFORMAT 'auto'",
+ "COMPUPDATE ON",
+ "TRUNCATECOLUMNS",
+ "IGNOREHEADER as 99",
+ ]
)
== 99
)
@@ -359,33 +397,63 @@ def test_get_ignoreheader_number():
assert (
get_ignoreheader_number(
- ["DATEFORMAT 'auto'", "COMPUPDATE ON", "TRUNCATECOLUMNS", "IGNOREHEADER is 1"]
+ [
+ "DATEFORMAT 'auto'",
+ "COMPUPDATE ON",
+ "TRUNCATECOLUMNS",
+ "IGNOREHEADER is 1",
+ ]
)
== 1
)
assert (
get_ignoreheader_number(
- ["DATEFORMAT 'auto'", "COMPUPDATE ON", "TRUNCATECOLUMNS", "IGNOREHEADER is 2"]
+ [
+ "DATEFORMAT 'auto'",
+ "COMPUPDATE ON",
+ "TRUNCATECOLUMNS",
+ "IGNOREHEADER is 2",
+ ]
)
== 2
)
assert (
get_ignoreheader_number(
- ["DATEFORMAT 'auto'", "COMPUPDATE ON", "TRUNCATECOLUMNS", "IGNOREHEADER is 99"]
+ [
+ "DATEFORMAT 'auto'",
+ "COMPUPDATE ON",
+ "TRUNCATECOLUMNS",
+ "IGNOREHEADER is 99",
+ ]
)
== 99
)
- assert get_ignoreheader_number(["DATEFORMAT 'auto'", "COMPUPDATE ON", "TRUNCATECOLUMNS"]) == 0
assert (
get_ignoreheader_number(
- ["DATEFORMAT 'auto'", "COMPUPDATE ON", "TRUNCATECOLUMNS", "IGNOREHEADERAS 2"]
+ ["DATEFORMAT 'auto'", "COMPUPDATE ON", "TRUNCATECOLUMNS"]
+ )
+ == 0
+ )
+ assert (
+ get_ignoreheader_number(
+ [
+ "DATEFORMAT 'auto'",
+ "COMPUPDATE ON",
+ "TRUNCATECOLUMNS",
+ "IGNOREHEADERAS 2",
+ ]
)
== 0
)
assert (
get_ignoreheader_number(
- ["DATEFORMAT 'auto'", "COMPUPDATE ON", "TRUNCATECOLUMNS", "SOMETHINGIGNOREHEADER AS 2"]
+ [
+ "DATEFORMAT 'auto'",
+ "COMPUPDATE ON",
+ "TRUNCATECOLUMNS",
+ "SOMETHINGIGNOREHEADER AS 2",
+ ]
)
== 0
)
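# Illustrative sketch (not part of the diff): get_ignoreheader_number extracts
# the row count from an "IGNOREHEADER as/is N" copy option and returns 0 when no
# such option is present, matching the expanded tests above.
from locopy.utility import get_ignoreheader_number

print(get_ignoreheader_number(["COMPUPDATE ON", "IGNOREHEADER as 1"]))  # 1
print(get_ignoreheader_number(["COMPUPDATE ON", "TRUNCATECOLUMNS"]))    # 0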