diff --git a/.github/workflows/edgetest.yml b/.github/workflows/edgetest.yml index a690653..b8b4d94 100644 --- a/.github/workflows/edgetest.yml +++ b/.github/workflows/edgetest.yml @@ -5,10 +5,14 @@ name: Run edgetest on: schedule: - cron: '35 17 * * 5' + workflow_dispatch: jobs: edgetest: runs-on: ubuntu-latest name: running edgetest + permissions: + contents: write + pull-requests: write steps: - uses: actions/checkout@v4 with: @@ -19,7 +23,7 @@ jobs: cp tests/data/.locopyrc ~/.locopyrc cp tests/data/.locopy-sfrc ~/.locopy-sfrc - id: run-edgetest - uses: fdosani/run-edgetest-action@v1.3 + uses: edgetest-dev/run-edgetest-action@v1.5 with: edgetest-flags: '-c pyproject.toml --export' base-branch: 'develop' diff --git a/locopy/_version.py b/locopy/_version.py index dc3ce4d..9a5b5d9 100644 --- a/locopy/_version.py +++ b/locopy/_version.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.6.2" +__version__ = "0.6.3" diff --git a/locopy/utility.py b/locopy/utility.py index 7b61f31..22cd6c0 100644 --- a/locopy/utility.py +++ b/locopy/utility.py @@ -30,6 +30,7 @@ import pandas as pd import polars as pl +import pyarrow as pa import yaml from locopy.errors import ( @@ -317,6 +318,20 @@ def validate_float_object(column): except (ValueError, TypeError): return None + def check_column_type_pyarrow(pa_dtype): + if pa.types.is_temporal(pa_dtype): + return "timestamp" + elif pa.types.is_boolean(pa_dtype): + return "boolean" + elif pa.types.is_integer(pa_dtype): + return "int" + elif pa.types.is_floating(pa_dtype): + return "float" + elif pa.types.is_string(pa_dtype): + return "varchar" + else: + return "varchar" + if warehouse_type.lower() not in ["snowflake", "redshift"]: raise ValueError( 'warehouse_type argument must be either "snowflake" or "redshift"' @@ -328,24 +343,28 @@ def validate_float_object(column): data = dataframe[column].dropna().reset_index(drop=True) if data.size == 0: column_type.append("varchar") - elif (data.dtype in ["datetime64[ns]", "M8[ns]"]) or ( - re.match(r"(datetime64\[ns\,\W)([a-zA-Z]+)(\])", str(data.dtype)) - ): - column_type.append("timestamp") - elif str(data.dtype).lower().startswith("bool"): - column_type.append("boolean") - elif str(data.dtype).startswith("object"): - data_type = validate_float_object(data) or validate_date_object(data) - if not data_type: - column_type.append("varchar") - else: - column_type.append(data_type) - elif str(data.dtype).lower().startswith("int"): - column_type.append("int") - elif str(data.dtype).lower().startswith("float"): - column_type.append("float") + elif isinstance(data.dtype, pd.ArrowDtype): + datatype = check_column_type_pyarrow(data.dtype.pyarrow_dtype) + column_type.append(datatype) else: - column_type.append("varchar") + if (data.dtype in ["datetime64[ns]", "M8[ns]"]) or ( + re.match(r"(datetime64\[ns\,\W)([a-zA-Z]+)(\])", str(data.dtype)) + ): + column_type.append("timestamp") + elif str(data.dtype).lower().startswith("bool"): + column_type.append("boolean") + elif str(data.dtype).startswith("object"): + data_type = validate_float_object(data) or validate_date_object(data) + if not data_type: + column_type.append("varchar") + else: + column_type.append(data_type) + elif str(data.dtype).lower().startswith("int"): + column_type.append("int") + elif str(data.dtype).lower().startswith("float"): + column_type.append("float") + else: + column_type.append("varchar") logger.info("Parsing column %s to %s", column, column_type[-1]) return OrderedDict(zip(list(dataframe.columns), column_type)) diff --git a/pyproject.toml b/pyproject.toml index ab77965..1f4065c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = [ { name="Faisal Dosani", email="faisal.dosani@capitalone.com" }, ] license = {text = "Apache Software License"} -dependencies = ["boto3<=1.35.43,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.3,>=0.25.2", "numpy<=2.0.2,>=1.22.0", "polars>=0.20.0"] +dependencies = ["boto3<=1.35.80,>=1.9.92", "PyYAML<=6.0.2,>=5.1", "pandas<=2.2.3,>=1.5.0", "numpy<=2.2.0,>=1.22.0", "polars>=0.20.0", "pyarrow>=10.0.1"] requires-python = ">=3.9.0" classifiers = [ @@ -104,7 +104,7 @@ ban-relative-imports = "all" convention = "numpy" [edgetest.envs.core] -python_version = "3.9" +python_version = "3.10" extras = [ "tests", "psycopg2", diff --git a/requirements.txt b/requirements.txt index 3a26156..7e9be3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,8 @@ -# -# This file is autogenerated by pip-compile with Python 3.11 -# by the following command: -# -# pip-compile --output-file=requirements.txt pyproject.toml -# - -boto3==1.34.126 +# This file was autogenerated by uv via the following command: +# uv pip compile pyproject.toml -o requirements.txt +boto3==1.35.80 # via locopy (pyproject.toml) -botocore==1.34.130 +botocore==1.35.80 # via # boto3 # s3transfer @@ -15,27 +10,29 @@ jmespath==1.0.1 # via # boto3 # botocore -numpy==1.26.4 +numpy==2.2.0 # via # locopy (pyproject.toml) # pandas -pandas==2.2.2 +pandas==2.2.3 # via locopy (pyproject.toml) -polars==1.6.0 +polars==1.17.1 + # via locopy (pyproject.toml) +pyarrow==18.1.0 # via locopy (pyproject.toml) python-dateutil==2.9.0.post0 # via # botocore # pandas -pytz==2024.1 +pytz==2024.2 # via pandas -pyyaml==6.0.1 +pyyaml==6.0.2 # via locopy (pyproject.toml) -s3transfer==0.10.1 +s3transfer==0.10.4 # via boto3 -six==1.16.0 +six==1.17.0 # via python-dateutil -tzdata==2024.1 +tzdata==2024.2 # via pandas -urllib3==1.26.20 +urllib3==2.2.3 # via botocore diff --git a/tests/test_utility.py b/tests/test_utility.py index ec6f8f2..2ccf772 100644 --- a/tests/test_utility.py +++ b/tests/test_utility.py @@ -29,6 +29,7 @@ from unittest import mock import locopy.utility as util +import pyarrow as pa import pytest from locopy.errors import ( CompressionError, @@ -388,7 +389,48 @@ def test_find_column_type_new(): "d": "varchar", "e": "boolean", } + assert find_column_type(input_text, "snowflake") == output_text_snowflake + assert find_column_type(input_text, "redshift") == output_text_redshift + + +def test_find_column_type_pyarrow(): + import pandas as pd + + input_text = pd.DataFrame.from_dict( + { + "a": [1], + "b": [pd.Timestamp("2017-01-01T12+0")], + "c": [1.2], + "d": ["a"], + "e": [True], + } + ) + input_text = input_text.astype( + dtype={ + "a": "int64[pyarrow]", + "b": pd.ArrowDtype(pa.timestamp("ns", tz="UTC")), + "c": "float64[pyarrow]", + "d": pd.ArrowDtype(pa.string()), + "e": "bool[pyarrow]", + } + ) + + output_text_snowflake = { + "a": "int", + "b": "timestamp", + "c": "float", + "d": "varchar", + "e": "boolean", + } + + output_text_redshift = { + "a": "int", + "b": "timestamp", + "c": "float", + "d": "varchar", + "e": "boolean", + } assert find_column_type(input_text, "snowflake") == output_text_snowflake assert find_column_type(input_text, "redshift") == output_text_redshift