From cb3d1c5813ba5fa94ef46b14e9993ae0b42b08b2 Mon Sep 17 00:00:00 2001 From: b8raoult <53792887+b8raoult@users.noreply.github.com> Date: Tue, 14 Mar 2023 07:41:15 +0000 Subject: [PATCH] apply new version of black --- ecmwf/opendata/__init__.py | 2 +- ecmwf/opendata/client.py | 29 ++++++++++++++++++++--------- ecmwf/opendata/date.py | 2 -- ecmwf/opendata/grib.py | 1 - setup.py | 2 +- tests/test_date.py | 1 - tests/test_examples.py | 1 - tests/test_stream.py | 1 - tools/check-index.py | 1 - tools/param-units.py | 2 -- 10 files changed, 22 insertions(+), 20 deletions(-) diff --git a/ecmwf/opendata/__init__.py b/ecmwf/opendata/__init__.py index 37a9c4e..838ee82 100644 --- a/ecmwf/opendata/__init__.py +++ b/ecmwf/opendata/__init__.py @@ -10,6 +10,6 @@ from .client import Client -__version__ = "0.1.2" +__version__ = "0.1.3" __all__ = ["Client"] diff --git a/ecmwf/opendata/client.py b/ecmwf/opendata/client.py index d6f8100..f4c81c1 100644 --- a/ecmwf/opendata/client.py +++ b/ecmwf/opendata/client.py @@ -47,7 +47,6 @@ def warning_once(*args, did_you_mean=None): - if repr(args) in ONCE: return @@ -109,21 +108,22 @@ def __init__( preserve_request_order=False, infer_stream_keyword=True, debug=False, + verify=True, ): self._url = None self.source = source self.beta = beta self.preserve_request_order = preserve_request_order self.infer_stream_keyword = infer_stream_keyword + self.session = requests.Session() + self.verify = verify if debug: logging.basicConfig(level=logging.DEBUG) @property def url(self): - if self._url is None: - if self.source.startswith("http://") or self.source.startswith("https://"): self._url = self.source else: @@ -141,12 +141,22 @@ def url(self): def retrieve(self, request=None, target=None, **kwargs): result = self._get_urls(request, target=target, use_index=True, **kwargs) - result.size = download(result.urls, target=result.target) + result.size = download( + result.urls, + target=result.target, + verify=self.verify, + session=self.session, + ) return result def download(self, request=None, target=None, **kwargs): result = self._get_urls(request, target=target, use_index=False, **kwargs) - result.size = download(result.urls, target=result.target) + result.size = download( + result.urls, + target=result.target, + verify=self.verify, + session=self.session, + ) return result def latest(self, request=None, **kwargs): @@ -171,7 +181,10 @@ def latest(self, request=None, **kwargs): date=date, **params, ) - codes = [robust(requests.head)(url).status_code for url in result.urls] + codes = [ + robust(self.session.head)(url, verify=self.verify).status_code + for url in result.urls + ] if len(codes) > 0 and all(c == 200 for c in codes): return date date -= delta @@ -179,7 +192,6 @@ def latest(self, request=None, **kwargs): raise ValueError("Cannot establish latest date for %r" % (result.for_urls,)) def _get_urls(self, request=None, use_index=None, target=None, **kwargs): - assert use_index in (True, False) if request is None: params = dict(**kwargs) @@ -230,7 +242,6 @@ def _get_urls(self, request=None, use_index=None, target=None, **kwargs): ) def get_parts(self, data_urls, for_index): - count = len(for_index) result = [] line = None @@ -240,7 +251,7 @@ def get_parts(self, data_urls, for_index): for url in data_urls: base, _ = os.path.splitext(url) index_url = f"{base}.index" - r = robust(requests.get)(index_url) + r = robust(self.session.get)(index_url, verify=self.verify) r.raise_for_status() parts = [] diff --git a/ecmwf/opendata/date.py b/ecmwf/opendata/date.py index 4f8b287..3bbb035 100644 --- a/ecmwf/opendata/date.py +++ b/ecmwf/opendata/date.py @@ -31,7 +31,6 @@ def canonical_time(time): def full_date(date, time=None): - if isinstance(date, datetime.date) and not isinstance(date, datetime.datetime): date = datetime.datetime(date.year, date.month, date.day) @@ -43,7 +42,6 @@ def full_date(date, time=None): date = datetime.datetime(date // 10000, date % 10000 // 100, date % 100) if isinstance(date, str): - try: return full_date(int(date), time) except ValueError: diff --git a/ecmwf/opendata/grib.py b/ecmwf/opendata/grib.py index 0ef8ce8..80209d9 100644 --- a/ecmwf/opendata/grib.py +++ b/ecmwf/opendata/grib.py @@ -25,7 +25,6 @@ def grib_index(path): - index = [] with open(path, "rb") as f: while True: diff --git a/setup.py b/setup.py index e9f0758..d4a65b6 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def read(fname): url="https://github.com/ecmwf/ecmwf-opendata", packages=setuptools.find_namespace_packages(include=["ecmwf.*"]), include_package_data=True, - install_requires=["multiurl>=0.2.0"], + install_requires=["multiurl>=0.2.1"], zip_safe=True, keywords="tool", classifiers=[ diff --git a/tests/test_date.py b/tests/test_date.py index 6a73e68..d0e7b89 100644 --- a/tests/test_date.py +++ b/tests/test_date.py @@ -7,7 +7,6 @@ @freeze_time("2022-01-21T13:21:34Z") def test_date_1(): - assert full_date("20010101") == datetime.datetime(2001, 1, 1) assert full_date(20010101) == datetime.datetime(2001, 1, 1) assert full_date("2001-01-01") == datetime.datetime(2001, 1, 1) diff --git a/tests/test_examples.py b/tests/test_examples.py index 5704e0f..057df9c 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -8,7 +8,6 @@ def example_list(): - examples = [] code = [] python = False diff --git a/tests/test_stream.py b/tests/test_stream.py index db7d670..14abc8b 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -8,7 +8,6 @@ def patch_stream(stream, time, type): def test_stream(): - assert patch_stream("oper", 0, "fc") == "oper" assert patch_stream("oper", 6, "fc") == "scda" assert patch_stream("oper", 12, "fc") == "oper" diff --git a/tools/check-index.py b/tools/check-index.py index 98400c7..991d00c 100755 --- a/tools/check-index.py +++ b/tools/check-index.py @@ -87,7 +87,6 @@ ("stream", rstream, stream), ("type", rtype, type), ): - if b != c: print("Mismatch: %r %r %r" % (a, b, c)) # assert False diff --git a/tools/param-units.py b/tools/param-units.py index f18c0f5..6ab79d8 100755 --- a/tools/param-units.py +++ b/tools/param-units.py @@ -45,7 +45,6 @@ with open("index.txt") as f: for j, url in enumerate(f): - url = url.rstrip() if ( @@ -67,7 +66,6 @@ lines.append(line) for i, line in enumerate(lines): - key = tuple(line.get(x) for x in ("type", "stream", "levtype", "param")) if key in seen: continue