From a20439e3e579418fa3c38375a55cc54319f4a2a1 Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Thu, 8 Apr 2021 04:36:56 -0700 Subject: [PATCH] Inline `Mapping[str, Mapping[str, T]]` in filtering.py. T now refers to the leaf type of the two level mapping. PiperOrigin-RevId: 367404372 Change-Id: Icf9c65f804206d0e7938396a88036285774259b3 --- haiku/_src/filtering.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/haiku/_src/filtering.py b/haiku/_src/filtering.py index f320431ce..91460024a 100644 --- a/haiku/_src/filtering.py +++ b/haiku/_src/filtering.py @@ -18,15 +18,16 @@ from typing import Any, Callable, Generator, Mapping, Tuple, TypeVar from haiku._src import data_structures -from haiku._src.typing import Params, State # pylint: disable=g-multiple-import import jax.numpy as jnp -T = TypeVar("T", Params, State) +T = TypeVar("T") InT = TypeVar("InT") OutT = TypeVar("OutT") -def traverse(structure: T) -> Generator[Tuple[str, str, Any], None, None]: +def traverse( + structure: Mapping[str, Mapping[str, T]], +) -> Generator[Tuple[str, str, T], None, None]: """Iterates over a structure yielding module names, names and values. NOTE: Items are iterated in key sorted order. @@ -46,8 +47,8 @@ def traverse(structure: T) -> Generator[Tuple[str, str, Any], None, None]: def partition( predicate: Callable[[str, str, jnp.ndarray], bool], - structure: T, -) -> Tuple[T, T]: + structure: Mapping[str, Mapping[str, T]], +) -> Tuple[Mapping[str, Mapping[str, T]], Mapping[str, Mapping[str, T]]]: """Partitions the input structure in two according to a given predicate. For a given set of parameters, you can use :func:`partition` to split them: @@ -79,10 +80,10 @@ def partition( def partition_n( - fn: Callable[[str, str, jnp.ndarray], int], - structure: T, + fn: Callable[[str, str, T], int], + structure: Mapping[str, Mapping[str, T]], n: int, -) -> Tuple[T, ...]: +) -> Tuple[Mapping[str, Mapping[str, T]], ...]: """Partitions a structure into `n` structures. For a given set of parameters, you can use :func:`partition_n` to split them @@ -121,9 +122,9 @@ def partition_n( def filter( # pylint: disable=redefined-builtin - predicate: Callable[[str, str, jnp.ndarray], bool], - structure: T, -) -> T: + predicate: Callable[[str, str, T], bool], + structure: Mapping[str, Mapping[str, T]], +) -> Mapping[str, Mapping[str, T]]: """Filters an input structure according to a user specified predicate. >>> params = {'linear': {'w': None, 'b': None}}