From 10df640eb06191113db19e0a58720cf8311bcfa8 Mon Sep 17 00:00:00 2001 From: Adam Cogdell Date: Fri, 20 Dec 2024 09:44:51 -0800 Subject: [PATCH] Save `StepMetadata` in the `AsyncCheckpointer`. PiperOrigin-RevId: 708348968 --- t5x/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/t5x/utils.py b/t5x/utils.py index 36ead8d0a..67823c027 100644 --- a/t5x/utils.py +++ b/t5x/utils.py @@ -256,6 +256,7 @@ def save( path: str, item: train_state_lib.TrainState, force: bool = False, + custom: dict[str, Any] | None = None, state_transformation_fns: Sequence[ checkpoints.SaveStateTransformationFn ] = (), @@ -268,6 +269,7 @@ def save( path: path to save item to. item: a TrainState PyTree to save. force: unused. + custom: unused. state_transformation_fns: Transformations to apply, in order, to the state before writing. concurrent_gb: the approximate number of gigabytes of partitionable