Skip to content

Commit

Permalink
Inline Mapping[str, Mapping[str, T]] in filtering.py.
Browse files Browse the repository at this point in the history
T now refers to the leaf type of the two level mapping.

PiperOrigin-RevId: 367404372
Change-Id: Icf9c65f804206d0e7938396a88036285774259b3
  • Loading branch information
tomhennigan authored and copybara-github committed Apr 8, 2021
1 parent c097dd3 commit a20439e
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions haiku/_src/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
from typing import Any, Callable, Generator, Mapping, Tuple, TypeVar

from haiku._src import data_structures
from haiku._src.typing import Params, State # pylint: disable=g-multiple-import
import jax.numpy as jnp

T = TypeVar("T", Params, State)
T = TypeVar("T")
InT = TypeVar("InT")
OutT = TypeVar("OutT")


def traverse(structure: T) -> Generator[Tuple[str, str, Any], None, None]:
def traverse(
structure: Mapping[str, Mapping[str, T]],
) -> Generator[Tuple[str, str, T], None, None]:
"""Iterates over a structure yielding module names, names and values.
NOTE: Items are iterated in key sorted order.
Expand All @@ -46,8 +47,8 @@ def traverse(structure: T) -> Generator[Tuple[str, str, Any], None, None]:

def partition(
predicate: Callable[[str, str, jnp.ndarray], bool],
structure: T,
) -> Tuple[T, T]:
structure: Mapping[str, Mapping[str, T]],
) -> Tuple[Mapping[str, Mapping[str, T]], Mapping[str, Mapping[str, T]]]:
"""Partitions the input structure in two according to a given predicate.
For a given set of parameters, you can use :func:`partition` to split them:
Expand Down Expand Up @@ -79,10 +80,10 @@ def partition(


def partition_n(
fn: Callable[[str, str, jnp.ndarray], int],
structure: T,
fn: Callable[[str, str, T], int],
structure: Mapping[str, Mapping[str, T]],
n: int,
) -> Tuple[T, ...]:
) -> Tuple[Mapping[str, Mapping[str, T]], ...]:
"""Partitions a structure into `n` structures.
For a given set of parameters, you can use :func:`partition_n` to split them
Expand Down Expand Up @@ -121,9 +122,9 @@ def partition_n(


def filter( # pylint: disable=redefined-builtin
predicate: Callable[[str, str, jnp.ndarray], bool],
structure: T,
) -> T:
predicate: Callable[[str, str, T], bool],
structure: Mapping[str, Mapping[str, T]],
) -> Mapping[str, Mapping[str, T]]:
"""Filters an input structure according to a user specified predicate.
>>> params = {'linear': {'w': None, 'b': None}}
Expand Down

0 comments on commit a20439e

Please sign in to comment.