Skip to content

Commit

Permalink
Merge pull request #229 from capitalone/develop
Browse files Browse the repository at this point in the history
Release v0.5.4
  • Loading branch information
Faisal authored Sep 7, 2023
2 parents 0ab5ab6 + 6b0ba2e commit a1a2aba
Show file tree
Hide file tree
Showing 9 changed files with 201 additions and 109 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/edgetest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ jobs:
cp tests/data/.locopyrc ~/.locopyrc
cp tests/data/.locopy-sfrc ~/.locopy-sfrc
- id: run-edgetest
uses: fdosani/run-edgetest-action@v1.0
uses: fdosani/run-edgetest-action@v1.3
with:
edgetest-flags: '-c pyproject.toml --export'
base-branch: 'develop'
skip-pr: 'false'
skip-pr: 'false'
2 changes: 1 addition & 1 deletion .github/workflows/publish-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.8'
python-version: '3.10'
- name: Install dependencies
run: python -m pip install -r requirements.txt .[dev]
- name: Build and publish
Expand Down
2 changes: 1 addition & 1 deletion CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1 +1 @@
* @fdosani @NikhilJArora @ak-gupta
* @fdosani @ak-gupta @jdawang @gladysteh99 @NikhilJArora
2 changes: 1 addition & 1 deletion locopy/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.5.3"
__version__ = "0.5.4"
48 changes: 30 additions & 18 deletions locopy/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,8 @@
from .errors import DBError, S3CredentialsError
from .logger import INFO, get_logger
from .s3 import S3
from .utility import (
compress_file_list,
concatenate_files,
find_column_type,
get_ignoreheader_number,
split_file,
write_file,
)
from .utility import (compress_file_list, concatenate_files, find_column_type,
get_ignoreheader_number, split_file, write_file)

logger = get_logger(__name__, INFO)

Expand Down Expand Up @@ -153,11 +147,15 @@ class Redshift(S3, Database):
Issue initializing S3 session
"""

def __init__(self, profile=None, kms_key=None, dbapi=None, config_yaml=None, **kwargs):
def __init__(
self, profile=None, kms_key=None, dbapi=None, config_yaml=None, **kwargs
):
try:
S3.__init__(self, profile, kms_key)
except S3CredentialsError:
logger.warning("S3 credentials were not found. S3 functionality is disabled")
logger.warning(
"S3 credentials were not found. S3 functionality is disabled"
)
Database.__init__(self, dbapi, config_yaml, **kwargs)

def connect(self):
Expand Down Expand Up @@ -304,7 +302,9 @@ def load_and_copy(
if splits > 1 and ignore_header > 0:
# remove the IGNOREHEADER from copy_options
logger.info("Removing the IGNOREHEADER option as split is enabled")
copy_options = [i for i in copy_options if not i.startswith("IGNOREHEADER ")]
copy_options = [
i for i in copy_options if not i.startswith("IGNOREHEADER ")
]

if compress:
copy_options.append("GZIP")
Expand Down Expand Up @@ -448,11 +448,16 @@ def unload(self, query, s3path, unload_options=None):

unload_options = unload_options or []
unload_options_text = " ".join(unload_options)
base_unload_string = "UNLOAD ('{0}')\n" "TO '{1}'\n" "CREDENTIALS '{2}'\n" "{3};"
base_unload_string = (
"UNLOAD ('{0}')\n" "TO '{1}'\n" "CREDENTIALS '{2}'\n" "{3};"
)

try:
sql = base_unload_string.format(
query.replace("'", r"\'"), s3path, self._credentials_string(), unload_options_text
query.replace("'", r"\'"),
s3path,
self._credentials_string(),
unload_options_text,
)
self.execute(sql, commit=True)
except Exception as e:
Expand Down Expand Up @@ -494,7 +499,10 @@ def _unload_generated_files(self):
list
List of S3 file names
"""
sql = "SELECT path FROM stl_unload_log " "WHERE query = pg_last_query_id() ORDER BY path"
sql = (
"SELECT path FROM stl_unload_log "
"WHERE query = pg_last_query_id() ORDER BY path"
)
try:
logger.info("Getting list of unloaded files")
self.execute(sql)
Expand Down Expand Up @@ -563,7 +571,7 @@ def insert_dataframe_to_table(
if create:
if not metadata:
logger.info("Metadata is missing. Generating metadata ...")
metadata = find_column_type(dataframe)
metadata = find_column_type(dataframe, "redshift")
logger.info("Metadata is complete. Creating new table ...")

create_join = (
Expand Down Expand Up @@ -592,16 +600,20 @@ def insert_dataframe_to_table(
"("
+ ", ".join(
[
"NULL" if pd.isnull(val) else "'" + str(val).replace("'", "''") + "'"
"NULL"
if pd.isnull(val)
else "'" + str(val).replace("'", "''") + "'"
for val in row
]
)
+ ")"
)
to_insert.append(none_row)
string_join = ", ".join(to_insert)
insert_query = """INSERT INTO {table_name} {columns} VALUES {values}""".format(
table_name=table_name, columns=column_sql, values=string_join
insert_query = (
"""INSERT INTO {table_name} {columns} VALUES {values}""".format(
table_name=table_name, columns=column_sql, values=string_join
)
)
self.execute(insert_query, verbose=verbose)
logger.info("Table insertion has completed")
43 changes: 32 additions & 11 deletions locopy/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@


def combine_options(options=None):
""" Returns the ``copy_options`` or ``format_options`` attribute with spaces in between and as
"""Returns the ``copy_options`` or ``format_options`` attribute with spaces in between and as
a string. If options is ``None`` then return an empty string.
Parameters
Expand Down Expand Up @@ -170,11 +170,15 @@ class Snowflake(S3, Database):
Issue initializing S3 session
"""

def __init__(self, profile=None, kms_key=None, dbapi=None, config_yaml=None, **kwargs):
def __init__(
self, profile=None, kms_key=None, dbapi=None, config_yaml=None, **kwargs
):
try:
S3.__init__(self, profile, kms_key)
except S3CredentialsError:
logger.warning("S3 credentials were not found. S3 functionality is disabled")
logger.warning(
"S3 credentials were not found. S3 functionality is disabled"
)
logger.warning("Only internal stages are available")
Database.__init__(self, dbapi, config_yaml, **kwargs)

Expand All @@ -196,7 +200,9 @@ def connect(self):
if self.connection.get("schema") is not None:
self.execute("USE SCHEMA {0}".format(self.connection["schema"]))

def upload_to_internal(self, local, stage, parallel=4, auto_compress=True, overwrite=True):
def upload_to_internal(
self, local, stage, parallel=4, auto_compress=True, overwrite=True
):
"""
Upload file(s) to an internal stage via the ``PUT`` command.
Expand Down Expand Up @@ -249,9 +255,13 @@ def download_from_internal(self, stage, local=None, parallel=10):
if local is None:
local = os.getcwd()
local_uri = PurePath(local).as_posix()
self.execute("GET {0} 'file://{1}' PARALLEL={2}".format(stage, local_uri, parallel))
self.execute(
"GET {0} 'file://{1}' PARALLEL={2}".format(stage, local_uri, parallel)
)

def copy(self, table_name, stage, file_type="csv", format_options=None, copy_options=None):
def copy(
self, table_name, stage, file_type="csv", format_options=None, copy_options=None
):
"""Executes the ``COPY INTO <table>`` command to load CSV files from a stage into
a Snowflake table. If ``file_type == csv`` and ``format_options == None``, ``format_options``
will default to: ``["FIELD_DELIMITER='|'", "SKIP_HEADER=0"]``
Expand Down Expand Up @@ -286,15 +296,19 @@ def copy(self, table_name, stage, file_type="csv", format_options=None, copy_opt

if file_type not in COPY_FORMAT_OPTIONS:
raise ValueError(
"Invalid file_type. Must be one of {0}".format(list(COPY_FORMAT_OPTIONS.keys()))
"Invalid file_type. Must be one of {0}".format(
list(COPY_FORMAT_OPTIONS.keys())
)
)

if format_options is None and file_type == "csv":
format_options = ["FIELD_DELIMITER='|'", "SKIP_HEADER=0"]

format_options_text = combine_options(format_options)
copy_options_text = combine_options(copy_options)
base_copy_string = "COPY INTO {0} FROM '{1}' " "FILE_FORMAT = (TYPE='{2}' {3}) {4}"
base_copy_string = (
"COPY INTO {0} FROM '{1}' " "FILE_FORMAT = (TYPE='{2}' {3}) {4}"
)
try:
sql = base_copy_string.format(
table_name, stage, file_type, format_options_text, copy_options_text
Expand Down Expand Up @@ -350,7 +364,9 @@ def unload(

if file_type not in COPY_FORMAT_OPTIONS:
raise ValueError(
"Invalid file_type. Must be one of {0}".format(list(UNLOAD_FORMAT_OPTIONS.keys()))
"Invalid file_type. Must be one of {0}".format(
list(UNLOAD_FORMAT_OPTIONS.keys())
)
)

if format_options is None and file_type == "csv":
Expand All @@ -364,7 +380,12 @@ def unload(

try:
sql = base_unload_string.format(
stage, table_name, file_type, format_options_text, header, copy_options_text
stage,
table_name,
file_type,
format_options_text,
header,
copy_options_text,
)
self.execute(sql, commit=True)
except Exception as e:
Expand Down Expand Up @@ -422,7 +443,7 @@ def insert_dataframe_to_table(
if create:
if not metadata:
logger.info("Metadata is missing. Generating metadata ...")
metadata = find_column_type(dataframe)
metadata = find_column_type(dataframe, "snowflake")
logger.info("Metadata is complete. Creating new table ...")

create_join = (
Expand Down
29 changes: 19 additions & 10 deletions locopy/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,8 @@

import yaml

from .errors import (
CompressionError,
CredentialsError,
LocopyConcatError,
LocopyIgnoreHeaderError,
LocopySplitError,
)
from .errors import (CompressionError, CredentialsError, LocopyConcatError,
LocopyIgnoreHeaderError, LocopySplitError)
from .logger import INFO, get_logger

logger = get_logger(__name__, INFO)
Expand Down Expand Up @@ -250,7 +245,7 @@ def read_config_yaml(config_yaml):


# make it more granular, eg. include length
def find_column_type(dataframe):
def find_column_type(dataframe, warehouse_type: str):
"""
Find data type of each column from the dataframe.
Expand All @@ -271,6 +266,9 @@ def find_column_type(dataframe):
----------
dataframe : Pandas dataframe
warehouse_type: str
Required to properly determine format of uploaded data, either "snowflake" or "redshift".
Returns
-------
dict
Expand All @@ -284,10 +282,16 @@ def find_column_type(dataframe):
def validate_date_object(column):
try:
pd.to_datetime(column)
if re.search(r"\d+:\d+:\d+", column.sample(1).to_string(index=False)):
sample_data = column.sample(1).to_string(index=False)
if re.search(r"\d+:\d+:\d+", sample_data):
return "timestamp"
else:
elif warehouse_type == "redshift" or re.search(
r"(\d{4}-\d{2}-\d{2})|(\d{2}-[A-Z]{3}-\d{4})|(\d{2}/\d{2}/\d{4})",
sample_data,
):
return "date"
else:
return "varchar"
except (ValueError, TypeError):
return None

Expand All @@ -298,6 +302,11 @@ def validate_float_object(column):
except (ValueError, TypeError):
return None

if warehouse_type.lower() not in ["snowflake", "redshift"]:
raise ValueError(
'warehouse_type argument must be either "snowflake" or "redshift"'
)

column_type = []
for column in dataframe.columns:
logger.debug("Checking column: %s", column)
Expand Down
26 changes: 4 additions & 22 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,7 @@ authors = [
{ name="Faisal Dosani", email="faisal.dosani@capitalone.com" },
]
license = {text = "Apache Software License"}
dependencies = [
"boto3<=1.28.3,>=1.9.92",
"PyYAML<=6.0,>=5.1",
"pandas<=2.0.3,>=0.25.2",
"numpy<=1.25.1,>=1.22.0",
]
dependencies = ["boto3<=1.28.39,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.1.0,>=0.25.2", "numpy<=1.25.2,>=1.22.0"]

requires-python = ">=3.8.0"
classifiers = [
Expand Down Expand Up @@ -45,24 +40,11 @@ psycopg2 = ["psycopg2-binary>=2.7.7"]
pg8000 = ["pg8000>=1.13.1"]
snowflake = ["snowflake-connector-python[pandas]>=2.1.2"]
docs = ["sphinx", "sphinx_rtd_theme"]
tests = [
"hypothesis",
"pytest",
"pytest-cov",
]
qa = [
"pre-commit",
"black",
"isort",
]
tests = ["hypothesis", "pytest", "pytest-cov"]
qa = ["pre-commit", "black", "isort"]
build = ["build", "twine", "wheel"]
edgetest = ["edgetest", "edgetest-conda"]
dev = [
"locopy[tests]",
"locopy[docs]",
"locopy[qa]",
"locopy[build]",
]
dev = ["locopy[tests]", "locopy[docs]", "locopy[qa]", "locopy[build]"]

[isort]
multi_line_output = 3
Expand Down
Loading

0 comments on commit a1a2aba

Please sign in to comment.