Skip to content

Commit

Permalink
Add native support for expressions via filters (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
J535D165 authored Dec 22, 2024
1 parent c8d00f8 commit 804c741
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 25 deletions.
99 changes: 85 additions & 14 deletions pyalex/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,32 @@ def __setattr__(self, key, value):
)


class or_(dict):
    """Dict subclass marking a group of filters as an OR expression.

    It adds no behaviour of its own; code that serialises filters can use
    ``isinstance(d, or_)`` to tell it apart from a plain dict.
    """


class _LogicalExpression:
    """Filter value paired with a prefix operator token.

    Subclasses set ``token``; ``str()`` renders the token immediately
    followed by the wrapped value (for example ``"!us"`` for a token of
    ``"!"`` and a value of ``"us"``).

    Args:
        value: The value the operator applies to. It is stored unchanged on
            ``self.value``.
    """

    # Operator prefix placed before the value; subclasses override it.
    token = None

    def __init__(self, value):
        self.value = value

    def __str__(self) -> str:
        return f"{self.token}{self.value}"

    def __repr__(self) -> str:
        # Readable form for debugging and assertion diffs; str() is unchanged.
        return f"{type(self).__name__}({self.value!r})"


class not_(_LogicalExpression):
    """Negated filter value; renders as ``!`` followed by the value."""

    token = "!"


class gt_(_LogicalExpression):
    """Greater-than filter value; renders as ``>`` followed by the value."""

    token = ">"


class lt_(_LogicalExpression):
    """Less-than filter value; renders as ``<`` followed by the value."""

    token = "<"


def _quote_oa_value(v):
"""Prepare a value for the OpenAlex API.
Expand All @@ -41,30 +67,40 @@ def _quote_oa_value(v):
if isinstance(v, bool):
return str(v).lower()

if isinstance(v, _LogicalExpression) and isinstance(v.value, str):
v.value = quote_plus(v.value)
return v

if isinstance(v, str):
return quote_plus(v)

return v


def _flatten_kv(d, prefix=""):
def _flatten_kv(d, prefix=None, logical="+"):
if prefix is None and not isinstance(d, dict):
raise ValueError("prefix should be set if d is not a dict")

if isinstance(d, dict):
logical_subd = "|" if isinstance(d, or_) else logical

t = []
for k, v in d.items():
if isinstance(v, list):
t.extend([f"{prefix}.{k}:{_quote_oa_value(i)}" for i in v])
else:
new_prefix = f"{prefix}.{k}" if prefix else f"{k}"
x = _flatten_kv(v, prefix=new_prefix)
t.append(x)
x = _flatten_kv(
v, prefix=f"{prefix}.{k}" if prefix else f"{k}", logical=logical_subd
)
t.append(x)

return ",".join(t)
elif isinstance(d, list):
list_str = logical.join([f"{_quote_oa_value(i)}" for i in d])
return f"{prefix}:{list_str}"
else:
return f"{prefix}:{_quote_oa_value(d)}"


def _params_merge(params, add_params):
for k, _v in add_params.items():
for k in add_params.keys():
if (
k in params
and isinstance(params[k], dict)
Expand Down Expand Up @@ -113,6 +149,18 @@ def invert_abstract(inv_index):
return " ".join(map(lambda x: x[0], sorted(l_inv, key=lambda x: x[1])))


def _wrap_values_nested_dict(d, func):
    """Return a copy of ``d`` with ``func`` applied to every leaf value.

    Nested dicts are processed recursively, list values are mapped element
    by element, and any other value is passed to ``func`` directly.

    The input is never modified. A new mapping is built at every nesting
    level, so a dict the caller reuses for several queries is not silently
    rewritten with wrapped values.

    Args:
        d: Possibly nested dict of values.
        func: Callable applied to each non-dict value and to each list item.

    Returns:
        A new mapping of the same class as ``d`` holding the wrapped values.
    """
    # type(d)() rather than {} so dict subclasses keep their type.
    wrapped = type(d)()

    for key, value in d.items():
        if isinstance(value, dict):
            wrapped[key] = _wrap_values_nested_dict(value, func)
        elif isinstance(value, list):
            wrapped[key] = [func(item) for item in value]
        else:
            wrapped[key] = func(value)

    return wrapped


class QueryError(ValueError):
pass

Expand Down Expand Up @@ -207,9 +255,6 @@ class BaseOpenAlex:
def __init__(self, params=None):
self.params = params

def _get_multi_items(self, record_list):
return self.filter(openalex_id="|".join(record_list)).get()

def _full_collection_name(self):
if self.params is not None and "q" in self.params.keys():
return (
Expand All @@ -234,10 +279,14 @@ def __getattr__(self, key):

def __getitem__(self, record_id):
if isinstance(record_id, list):
return self._get_multi_items(record_id)
if len(record_id) > 100:
raise ValueError("OpenAlex does not support more than 100 ids")

return self.filter_or(openalex_id=record_id).get(per_page=len(record_id))

return self._get_from_url(
f"{self._full_collection_name()}/{record_id}", return_meta=False
f"{self._full_collection_name()}/{_quote_oa_value(record_id)}",
return_meta=False,
)

@property
Expand Down Expand Up @@ -322,7 +371,10 @@ def paginate(self, method="cursor", page=1, per_page=None, cursor="*", n_max=100
def random(self):
return self.__getitem__("random")

def _add_params(self, argument, new_params):
def _add_params(self, argument, new_params, raise_if_exists=False):
if raise_if_exists:
raise NotImplementedError("raise_if_exists is not implemented")

if self.params is None:
self.params = {argument: new_params}
elif argument in self.params and isinstance(self.params[argument], dict):
Expand All @@ -336,6 +388,25 @@ def filter(self, **kwargs):
self._add_params("filter", kwargs)
return self

def filter_and(self, **kwargs):
    """Add filters; an explicitly named alias that delegates to ``filter``.

    Args:
        **kwargs: Filter names and values, forwarded unchanged.

    Returns:
        Whatever ``self.filter(**kwargs)`` returns.
    """
    return self.filter(**kwargs)

def filter_or(self, **kwargs):
    """Add filters wrapped in ``or_`` so they are marked as an OR group.

    Args:
        **kwargs: Filter names and values; list values are the alternatives.

    Returns:
        ``self``, so calls can be chained.
    """
    or_filters = or_(kwargs)
    self._add_params("filter", or_filters, raise_if_exists=False)
    return self

def filter_not(self, **kwargs):
    """Add filters whose values are each wrapped in ``not_``.

    Args:
        **kwargs: Filter names and values; may contain nested dicts and lists.

    Returns:
        ``self``, so calls can be chained.
    """
    negated = _wrap_values_nested_dict(kwargs, not_)
    self._add_params("filter", negated)
    return self

def filter_gt(self, **kwargs):
    """Add filters whose values are each wrapped in ``gt_``.

    Args:
        **kwargs: Filter names and values; may contain nested dicts and lists.

    Returns:
        ``self``, so calls can be chained.
    """
    lower_bounds = _wrap_values_nested_dict(kwargs, gt_)
    self._add_params("filter", lower_bounds)
    return self

def filter_lt(self, **kwargs):
    """Add filters whose values are each wrapped in ``lt_``.

    Args:
        **kwargs: Filter names and values; may contain nested dicts and lists.

    Returns:
        ``self``, so calls can be chained.
    """
    upper_bounds = _wrap_values_nested_dict(kwargs, lt_)
    self._add_params("filter", upper_bounds)
    return self

def search_filter(self, **kwargs):
self._add_params("filter", {f"{k}.search": v for k, v in kwargs.items()})
return self
Expand Down
73 changes: 62 additions & 11 deletions tests/test_pyalex.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,12 @@ def test_multi_works():
# the work to extract the referenced works of
w = Works()["W2741809807"]

assert len(Works()[w["referenced_works"]]) == 25
assert len(Works()[w["referenced_works"]]) >= 38

assert (
len(Works().filter_or(openalex_id=w["referenced_works"]).get(per_page=100))
>= 38
)


def test_works_multifilter():
Expand Down Expand Up @@ -278,33 +283,80 @@ def test_random_publishers():


def test_and_operator():
# https://github.com/J535D165/pyalex/issues/11
url = "https://api.openalex.org/works?filter=institutions.country_code:tw,institutions.country_code:hk,institutions.country_code:us,publication_year:2022"
urls = [
"https://api.openalex.org/works?filter=institutions.country_code:tw,institutions.country_code:hk,institutions.country_code:us,publication_year:2022",
"https://api.openalex.org/works?filter=institutions.country_code:tw+hk+us,publication_year:2022",
]

assert (
url
== Works()
Works()
.filter(
institutions={"country_code": ["tw", "hk", "us"]}, publication_year=2022
)
.url
in urls
)
assert (
url
== Works()
Works()
.filter(institutions={"country_code": "tw"})
.filter(institutions={"country_code": "hk"})
.filter(institutions={"country_code": "us"})
.filter(publication_year=2022)
.url
in urls
)
assert (
url
== Works()
Works()
.filter(institutions={"country_code": ["tw", "hk"]})
.filter(institutions={"country_code": "us"})
.filter(publication_year=2022)
.url
in urls
)


def test_or_operator():
    """``filter_or`` joins the items of a list value with ``|`` in the URL."""
    query = Works().filter_or(
        institutions={"country_code": ["tw", "hk", "us"]}, publication_year=2022
    )
    expected = "https://api.openalex.org/works?filter=institutions.country_code:tw|hk|us,publication_year:2022"

    assert query.url == expected


def test_not_operator():
    """``filter_not`` prefixes a single value with ``!`` in the URL."""
    query = (
        Works()
        .filter_not(institutions={"country_code": "us"})
        .filter(publication_year=2022)
    )
    expected = "https://api.openalex.org/works?filter=institutions.country_code:!us,publication_year:2022"

    assert query.url == expected


def test_not_operator_list():
    """``filter_not`` prefixes every list item with ``!`` and joins with ``+``."""
    query = (
        Works()
        .filter_not(institutions={"country_code": ["tw", "hk", "us"]})
        .filter(publication_year=2022)
    )
    expected = "https://api.openalex.org/works?filter=institutions.country_code:!tw+!hk+!us,publication_year:2022"

    assert query.url == expected


@pytest.mark.skip("Wait for feedback on issue by OpenAlex")
def test_combined_operators():
    """Chaining ``filter_gt`` and ``filter_not`` on one key (currently skipped)."""
    # Accepted by the API (separate comma-joined filters):
    # https://api.openalex.org/works?filter=publication_year:>2022,publication_year:!2023

    # Not accepted by the API (operators combined with "+" on one key):
    # https://api.openalex.org/works?filter=publication_year:>2022+!2023

    query = Works().filter_gt(publication_year=2022).filter_not(publication_year=2023)
    expected = "https://api.openalex.org/works?filter=publication_year:>2022+!2023"

    assert query.url == expected


Expand Down Expand Up @@ -359,11 +411,10 @@ def test_filter_urlencoding():
)


@pytest.mark.skip("This test is not working due to inconsistencies in the API.")
def test_urlencoding_list():
assert (
Works()
.filter(
.filter_or(
doi=[
"https://doi.org/10.1207/s15327809jls0703&4_2",
"https://doi.org/10.1001/jama.264.8.944b",
Expand Down

0 comments on commit 804c741

Please sign in to comment.