
Elastic Net Regularizer #49

Merged · 10 commits · May 4, 2017
6 changes: 3 additions & 3 deletions dask_glm/algorithms.py
@@ -24,7 +24,7 @@

from dask_glm.utils import dot, exp, log1p
from dask_glm.families import Logistic
-from dask_glm.regularizers import _regularizers
+from dask_glm.regularizers import Regularizer


def compute_stepsize_dask(beta, step, Xbeta, Xstep, y, curr_val,
@@ -155,7 +155,7 @@ def admm(X, y, regularizer='l1', lamduh=0.1, rho=1, over_relax=1,

    pointwise_loss = family.pointwise_loss
    pointwise_gradient = family.pointwise_gradient
-    regularizer = _regularizers.get(regularizer, regularizer)  # string
+    regularizer = Regularizer.get(regularizer)
Collaborator: I'll be sad to see this line go, but 👍


    def create_local_gradient(func):
        @functools.wraps(func)
@@ -331,7 +331,7 @@ def proximal_grad(X, y, regularizer='l1', lamduh=0.1, family=Logistic,
    recalcRate = 10
    backtrackMult = firstBacktrackMult
    beta = np.zeros(p)
-    regularizer = _regularizers.get(regularizer, regularizer)  # string
+    regularizer = Regularizer.get(regularizer)

    if verbose:
        print('# -f |df/f| |dx/x| step')
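
With this change, `admm` and `proximal_grad` accept either a regularizer name string or a `Regularizer` instance. A minimal usage sketch, not from the PR itself — it assumes `dask_glm.datasets.make_classification` with default arguments, and the `admm` signature shown in the hunk above:

from dask_glm.algorithms import admm
from dask_glm.datasets import make_classification
from dask_glm.regularizers import ElasticNet

X, y = make_classification()  # hypothetical: default-sized toy problem

# Equivalent calls: resolve by name, or pass a configured instance directly.
beta_by_name = admm(X, y, regularizer='elastic_net', lamduh=0.1)
beta_by_inst = admm(X, y, regularizer=ElasticNet(weight=0.5), lamduh=0.1)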
19 changes: 12 additions & 7 deletions dask_glm/regularizers.py
@@ -9,6 +9,7 @@ class Regularizer(object):
    Defines the set of methods required to create a new regularization object. This includes
    the regularization functions itself and it's gradient, hessian, and proximal operator.
Collaborator: it's -> its
"""
_name = '_base'
Member: I think maybe this should be just `name`. I think @moody-marlin mentioned that users might be defining their own regularizers, so they shouldn't need to override private attributes.
Contributor (author): I agree, made the change.


    def f(self, beta):
        """Regularization function."""
@@ -44,9 +45,18 @@ def wrapped(beta, *args):
            return hess(beta, *args) + lam * self.hessian(beta)
        return wrapped

+    @classmethod
+    def get(cls, obj):
+        if isinstance(obj, cls):
+            return obj
+        elif isinstance(obj, str):
+            return {o._name: o for o in cls.__subclasses__()}[obj]()
+        raise TypeError('Not a valid regularizer object.')
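
`Regularizer.get` dispatches on type: an instance passes through unchanged, while a string is looked up against the `_name` of every direct subclass. A short sketch of the behavior, assuming only the classes in this diff:

from dask_glm.regularizers import Regularizer, L1, ElasticNet

reg = Regularizer.get('l1')           # name lookup via __subclasses__() -> a fresh L1()
assert isinstance(reg, L1)

enet = ElasticNet(weight=0.3)
assert Regularizer.get(enet) is enet  # instances are returned as-is

try:
    Regularizer.get(42)               # neither a Regularizer nor a known name
except TypeError:
    pass                              # raises TypeError('Not a valid regularizer object.')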


class L2(Regularizer):
    """L2 regularization."""
+    _name = 'l2'

    def f(self, beta):
        return (beta**2).sum() / 2
@@ -63,6 +73,7 @@ def proximal_operator(self, beta, t):

class L1(Regularizer):
    """L1 regularization."""
+    _name = 'l1'

    def f(self, beta):
        return (np.abs(beta)).sum()
@@ -83,6 +94,7 @@ def proximal_operator(self, beta, t):

class ElasticNet(Regularizer):
    """Elastic net regularization."""
+    _name = 'elastic_net'

    def __init__(self, weight=0.5):
        self.weight = weight
@@ -111,10 +123,3 @@ def func(b):
                return 0
            return (b - g * np.sign(b)) / (t - g + 1)
        return beta


-_regularizers = {
-    'l1': L1(),
-    'l2': L2(),
-    'elastic_net': ElasticNet()
-}
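
Because `get` discovers subclasses through `__subclasses__()`, the module-level `_regularizers` dict removed above becomes redundant, and user-defined regularizers register themselves just by subclassing — the point raised in the review. A hypothetical sketch (the `ScaledL2`/`'scaled_l2'` names are illustrative only, not part of the PR, and `_name` would become `name` per the review above):

import numpy as np
from dask_glm.regularizers import Regularizer

class ScaledL2(Regularizer):
    """Hypothetical user-defined penalty: scale * ||beta||^2 / 2."""
    _name = 'scaled_l2'

    def __init__(self, scale=2.0):
        self.scale = scale

    def f(self, beta):
        return self.scale * (beta ** 2).sum() / 2

    def gradient(self, beta):
        return self.scale * beta

    def hessian(self, beta):
        return self.scale * np.eye(len(beta))

    def proximal_operator(self, beta, t):
        # argmin_x t * (scale/2) * ||x||^2 + (1/2) * ||x - beta||^2
        return beta / (1 + self.scale * t)

# The subclass is now reachable by name, with no registry to update:
assert isinstance(Regularizer.get('scaled_l2'), ScaledL2)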
4 changes: 2 additions & 2 deletions dask_glm/tests/test_algos_families.py
@@ -9,7 +9,7 @@
from dask_glm.algorithms import (newton, bfgs, proximal_grad,
                                 gradient_descent, admm)
from dask_glm.families import Logistic, Normal, Poisson
-from dask_glm.regularizers import _regularizers
+from dask_glm.regularizers import Regularizer
from dask_glm.utils import sigmoid, make_y


@@ -89,7 +89,7 @@ def test_basic_unreg_descent(func, kwargs, N, nchunks, family):
@pytest.mark.parametrize('nchunks', [1, 10])
@pytest.mark.parametrize('family', [Logistic, Normal, Poisson])
@pytest.mark.parametrize('lam', [0.01, 1.2, 4.05])
-@pytest.mark.parametrize('reg', list(_regularizers.values()))
+@pytest.mark.parametrize('reg', [r() for r in Regularizer.__subclasses__()])
Collaborator: Nice.

def test_basic_reg_descent(func, kwargs, N, nchunks, family, lam, reg):
    beta = np.random.normal(size=2)
    M = len(beta)
6 changes: 3 additions & 3 deletions dask_glm/tests/test_estimators.py
@@ -3,16 +3,16 @@
from dask_glm.estimators import LogisticRegression, LinearRegression, PoissonRegression
from dask_glm.datasets import make_classification, make_regression, make_poisson
from dask_glm.algorithms import _solvers
-from dask_glm.regularizers import _regularizers
+from dask_glm.regularizers import Regularizer


-@pytest.fixture(params=_solvers.keys())
+@pytest.fixture(params=[r() for r in Regularizer.__subclasses__()])
def solver(request):
"""Parametrized fixture for all the solver names"""
return request.param


-@pytest.fixture(params=_regularizers.keys())
+@pytest.fixture(params=[r() for r in Regularizer.__subclasses__()])
def regularizer(request):
"""Parametrized fixture for all the regularizer names"""
return request.param