[BUG] can not pickle models with TrappingSR3 #585

Open
cmower opened this issue Dec 19, 2024 · 1 comment · May be fixed by #586

cmower commented Dec 19, 2024

Models created with SINDy(optimizer=ps.TrappingSR3()) cannot be pickled. My current workaround is to use dill.

Reproducing code example:

import pickle
import numpy as np
import pysindy as ps
from pysindy import SINDy
from pysindy.utils import lorenz
from scipy.integrate import solve_ivp

dt = 0.002
t_train = np.arange(0, 10, dt)
x0_train = [-8, 8, 27]
t_train_span = (t_train[0], t_train[-1])
x_train = solve_ivp(lorenz, t_train_span, x0_train, t_eval=t_train).y.T

model = SINDy(
    optimizer=ps.TrappingSR3(),  # with this line commented out the script succeeds; with it, pickling fails
)
model.fit(x_train, t=dt)
model.print()

with open("model.pkl", "wb") as f:
    pickle.dump(model, f)

del model

with open("model.pkl", "rb") as f:
    model = pickle.load(f)

model.print()

Error message:

Traceback (most recent call last):
  File "example_fail.py", line 25, in <module>
    pickle.dump(model, f)
AttributeError: Can't pickle local object 'get_regularization.<locals>.<lambda>'
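
One way to narrow down which attribute triggers an error like this is to try pickling each instance attribute on its own. The helper below is a hypothetical, stdlib-only sketch: `find_unpicklable_attrs`, `make_reg`, and `ToyOptimizer` are illustrative names, not pysindy API, and it assumes the object keeps its state in `__dict__` (you could point it at the optimizer object passed to SINDy).

```python
import pickle


def find_unpicklable_attrs(obj):
    """Return {attribute name: exception type name} for attributes that fail to pickle."""
    failures = {}
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # pickle may raise PicklingError, AttributeError, or TypeError
            failures[name] = type(exc).__name__
    return failures


def make_reg():
    # A lambda returned from a function is a "local object": pickle cannot find it by name.
    return lambda x: sum(abs(v) for v in x)


class ToyOptimizer:
    # Hypothetical stand-in for an optimizer that stores a looked-up function.
    def __init__(self):
        self.threshold = 0.1
        self.reg = make_reg()


bad = find_unpicklable_attrs(ToyOptimizer())
print(bad)
```

Only the lambda-valued attribute should be reported; the exact exception type differs across Python versions (older versions raise `AttributeError`, newer ones `PicklingError`).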

PySINDy/Python version information:

1.7.5 3.8.20 | packaged by conda-forge | (default, Sep 30 2024, 17:52:49) 
[GCC 13.3.0]

Jacob-Stevens-Haas (Member) commented Dec 20, 2024

Thanks for the bug report, @cmower!

Fortunately, there's already a test, test_optimizers.test_pickle, so this is easy to reproduce. Adding SR3 to the test's optimizer parametrization produces a similar error message (not an identical one, because some of the relevant lines have changed since your version).

Context
get_regularization is how SR3 looks up a regularization function from its string name; get_prox does the same for proximal operators. The functions they return are assigned to self.reg and self.prox, respectively.
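
The traceback follows from how such a lookup can be implemented. Pickle saves a function as a reference to its module and qualified name, and a name containing `<locals>` (any function or lambda created inside another function) cannot be looked up again. The sketch below is a simplified, hypothetical stand-in (`get_regularization_local`, `get_regularization_module`, and `regularization_l1` are illustrative names; the real pysindy code differs) contrasting a lookup that returns a lambda with one that returns a module-level function.

```python
import pickle


def regularization_l1(x, weight):
    # Module-level function: pickle stores it as a reference to "<module>.regularization_l1".
    return weight * sum(abs(v) for v in x)


def get_regularization_local(name):
    # Returns a lambda created at call time; its __qualname__ contains "<locals>".
    if name == "l1":
        return lambda x, weight: weight * sum(abs(v) for v in x)
    raise NotImplementedError(name)


def get_regularization_module(name):
    # Returns a module-level function, which pickle can find again by name.
    registry = {"l1": regularization_l1}
    return registry[name]


local_reg = get_regularization_local("l1")
print(local_reg.__qualname__)  # get_regularization_local.<locals>.<lambda>

try:
    pickle.dumps(local_reg)
    local_failed = False
except (pickle.PicklingError, AttributeError) as exc:
    local_failed = True
    print("local lambda failed:", exc)

module_reg = get_regularization_module("l1")
restored = pickle.loads(pickle.dumps(module_reg))
print(restored is regularization_l1)  # True
```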

Additional troubleshooting notes:

  • We have a decorator, _validate_prox_and_reg_inputs, that adds guard code, so the wrapper needs @functools.wraps in order to be picklable (easy, done)
  • What we're wrapping needs to be a name available at module level (e.g. regularization_l2) (easy, done)
  • Because we're decorating dynamically, the _validate...() version and the actual module-level names have different definitions, raising a PicklingError. (hard)
  • We currently have to delay decoration since (a) we need to know whether we're using a weighted regularizer to decorate, but (b) the guard code we're adding in the decorator doesn't run until later.
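
The third bullet can be reproduced without pysindy. `functools.wraps` copies `__module__` and `__qualname__` onto the wrapper, so pickle tries to save the wrapper as a reference to the module-level name; when decoration happens at runtime, that name still points to the undecorated function, and pickle refuses because the two objects differ. `_validate_inputs` and `regularization_l2` below are hypothetical simplifications, not the actual pysindy decorator.

```python
import functools
import pickle


def _validate_inputs(func):
    # Hypothetical guard decorator; functools.wraps copies __module__, __name__, __qualname__.
    @functools.wraps(func)
    def wrapper(x, weight):
        if weight < 0:
            raise ValueError("weight must be non-negative")
        return func(x, weight)
    return wrapper


def regularization_l2(x, weight):
    # Module-level name that pickle will look up when saving the wrapper.
    return weight * sum(v * v for v in x)


# A plain module-level function round-trips fine.
plain_ok = pickle.loads(pickle.dumps(regularization_l2)) is regularization_l2

# Decorating at runtime (as when decoration is delayed until the regularizer is known)
# creates a new object whose __qualname__ claims to be "regularization_l2"...
checked = _validate_inputs(regularization_l2)
print(checked.__qualname__)  # regularization_l2

# ...but the module-level name still refers to the undecorated function.
try:
    pickle.dumps(checked)
    dynamic_failed = False
except pickle.PicklingError as exc:
    # The object found under "regularization_l2" is not `checked`, so pickling is refused.
    dynamic_failed = True
    print(exc)
```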

The solution is either to apply the guards within SR3 where it calls self.prox and self.reg, or to decorate at the source. If we decorate at the source, we need to remove the "weighted" versions of the strings and instead use the type of the weight to infer whether the user wants weighted regularizers.
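
A minimal sketch of the "decorate at the source" option, assuming a scalar weight means an unweighted regularizer and a sequence means an elementwise-weighted one. Because the decorator runs once at definition time, the module-level name is the wrapped function itself, so pickle's name lookup finds the same object. All names here (`_validate_inputs`, `regularization_l1`, `_REGULARIZERS`, `ToyOptimizer`) are hypothetical, and plain lists stand in for NumPy arrays.

```python
import functools
import pickle
from numbers import Real


def _validate_inputs(func):
    # Guard code applied once, when the module-level function is defined.
    @functools.wraps(func)
    def wrapper(x, weight):
        if isinstance(weight, Real):
            if weight < 0:
                raise ValueError("scalar weight must be non-negative")
        elif len(weight) != len(x) or any(w < 0 for w in weight):
            raise ValueError("weights must be non-negative and match x in length")
        return func(x, weight)
    return wrapper


@_validate_inputs
def regularization_l1(x, weight):
    # The type of `weight` selects the unweighted or weighted form,
    # so no separate "weighted_l1" string is needed.
    if isinstance(weight, Real):
        return weight * sum(abs(v) for v in x)
    return sum(w * abs(v) for w, v in zip(weight, x))


_REGULARIZERS = {"l1": regularization_l1}


class ToyOptimizer:
    # Hypothetical stand-in for SR3: stores the looked-up, already-decorated function.
    def __init__(self, regularizer="l1", weight=0.1):
        self.reg = _REGULARIZERS[regularizer]
        self.weight = weight


opt = ToyOptimizer(weight=[0.5, 2.0])
restored = pickle.loads(pickle.dumps(opt))
print(restored.reg is regularization_l1)  # True
print(restored.reg([1.0, -3.0], restored.weight))  # 6.5
```

The trade-off is that the guard code can no longer depend on configuration chosen at construction time; anything it needs must be derivable from its arguments, which is why the weight's type has to carry the weighted/unweighted distinction.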

Heads up, @himkwtn: this issue existed before you, but since it involves code you wrote in #544 and #548, it might be interesting. I'll CC you on the PR if you want to see the fix.

Jacob-Stevens-Haas linked a pull request Dec 20, 2024 that will close this issue