Elastic Net Regularizer (#49)
* create regularizer base class, elastic net regularization, update readme with dev setup.

* vectorize proximal operator for elastic net

* fix l2, write tests for regularizers

* add tests for elastic net

* fix flake, change to string

* use base regularizer class to retrieve subclasses via string (see the usage sketch after this list).

* add tests for get

* fix docstrings, add hessian for l1.
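
A minimal sketch of the string lookup described above (illustrative only, not part of the commit):

    from dask_glm.regularizers import Regularizer, ElasticNet

    # strings resolve to Regularizer subclasses via their ``name`` attribute
    reg = Regularizer.get('elastic_net')

    # instances pass through unchanged, e.g. with a custom L1/L2 mixing weight
    reg = Regularizer.get(ElasticNet(weight=0.3))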
postelrich authored and cicdw committed May 4, 2017
1 parent b251bfc commit b2b6f10
Showing 8 changed files with 396 additions and 68 deletions.
14 changes: 14 additions & 0 deletions README.rst
@@ -5,5 +5,19 @@ Generalized Linear Models in Dask

*This library is not ready for use.*

Developer Setup
---------------
Setup environment (from repo directory)::

    conda env create
    source activate dask_glm
    pip install -e .

Run tests::

    py.test



.. |Build Status| image:: https://travis-ci.org/dask/dask-glm.svg?branch=master
:target: https://travis-ci.org/dask/dask-glm
10 changes: 5 additions & 5 deletions dask_glm/algorithms.py
@@ -24,7 +24,7 @@

from dask_glm.utils import dot, exp, log1p
from dask_glm.families import Logistic
-from dask_glm.regularizers import L1, _regularizers
+from dask_glm.regularizers import Regularizer


def compute_stepsize_dask(beta, step, Xbeta, Xstep, y, curr_val,
@@ -150,12 +150,12 @@ def newton(X, y, max_iter=50, tol=1e-8, family=Logistic):
    return beta


-def admm(X, y, regularizer=L1, lamduh=0.1, rho=1, over_relax=1,
+def admm(X, y, regularizer='l1', lamduh=0.1, rho=1, over_relax=1,
         max_iter=250, abstol=1e-4, reltol=1e-2, family=Logistic):

    pointwise_loss = family.pointwise_loss
    pointwise_gradient = family.pointwise_gradient
-    regularizer = _regularizers.get(regularizer, regularizer)  # string
+    regularizer = Regularizer.get(regularizer)

    def create_local_gradient(func):
        @functools.wraps(func)
@@ -319,7 +319,7 @@ def bfgs(X, y, max_iter=500, tol=1e-14, family=Logistic):
    return beta


-def proximal_grad(X, y, regularizer=L1, lamduh=0.1, family=Logistic,
+def proximal_grad(X, y, regularizer='l1', lamduh=0.1, family=Logistic,
                  max_iter=100, tol=1e-8, verbose=False):

    n, p = X.shape
@@ -331,7 +331,7 @@ def proximal_grad(X, y, regularizer=L1, lamduh=0.1, family=Logistic,
    recalcRate = 10
    backtrackMult = firstBacktrackMult
    beta = np.zeros(p)
-    regularizer = _regularizers.get(regularizer, regularizer)  # string
+    regularizer = Regularizer.get(regularizer)

    if verbose:
        print('# -f |df/f| |dx/x| step')
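
A minimal call sketch of the updated solver signatures, assuming dask array inputs built the same way as in the tests further down (the data setup here is illustrative):

    import numpy as np
    import dask.array as da

    from dask_glm.algorithms import admm, proximal_grad
    from dask_glm.regularizers import ElasticNet
    from dask_glm.utils import make_y

    N, p, nchunks = 1000, 5, 10
    X = da.random.random((N, p), chunks=(N // nchunks, p))
    y = make_y(X, beta=np.random.normal(size=p), chunks=(N // nchunks,))

    # regularizers can now be passed by name ...
    beta = admm(X, y, regularizer='elastic_net', lamduh=0.1)

    # ... or as instances, e.g. to control the L1/L2 mixing weight
    beta = proximal_grad(X, y, regularizer=ElasticNet(weight=0.7), lamduh=0.1)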
152 changes: 96 additions & 56 deletions dask_glm/regularizers.py
@@ -3,85 +3,125 @@
import numpy as np


-class L2(object):
-
-    @staticmethod
-    def proximal_operator(beta, t):
-        return 1 / (1 + t) * beta
-
-    @staticmethod
-    def hessian(beta):
-        return 2 * np.eye(len(beta))
-
-    @staticmethod
-    def add_reg_hessian(hess, lam):
-        def wrapped(beta, *args):
-            return hess(beta, *args) + lam * L2.hessian(beta)
-        return wrapped
-
-    @staticmethod
-    def f(beta):
-        return (beta**2).sum()
-
-    @staticmethod
-    def add_reg_f(f, lam):
-        def wrapped(beta, *args):
-            return f(beta, *args) + lam * L2.f(beta)
-        return wrapped
-
-    @staticmethod
-    def gradient(beta):
-        return 2 * beta
-
-    @staticmethod
-    def add_reg_grad(grad, lam):
-        def wrapped(beta, *args):
-            return grad(beta, *args) + lam * L2.gradient(beta)
-        return wrapped
-
-
-class L1(object):
-
-    @staticmethod
-    def proximal_operator(beta, t):
-        z = np.maximum(0, beta - t) - np.maximum(0, -beta - t)
-        return z
-
-    @staticmethod
-    def hessian(beta):
-        raise ValueError('l1 norm is not twice differentiable!')
-
-    @staticmethod
-    def add_reg_hessian(hess, lam):
-        def wrapped(beta, *args):
-            return hess(beta, *args) + lam * L1.hessian(beta)
-        return wrapped
-
-    @staticmethod
-    def f(beta):
-        return (np.abs(beta)).sum()
-
-    @staticmethod
-    def add_reg_f(f, lam):
-        def wrapped(beta, *args):
-            return f(beta, *args) + lam * L1.f(beta)
-        return wrapped
-
-    @staticmethod
-    def gradient(beta):
-        if np.any(np.isclose(beta, 0)):
-            raise ValueError('l1 norm is not differentiable at 0!')
-        else:
-            return np.sign(beta)
-
-    @staticmethod
-    def add_reg_grad(grad, lam):
-        def wrapped(beta, *args):
-            return grad(beta, *args) + lam * L1.gradient(beta)
-        return wrapped
-
-
-_regularizers = {
-    'l1': L1,
-    'l2': L2,
-}
+class Regularizer(object):
+    """Abstract base class for regularization objects.
+
+    Defines the set of methods required to create a new regularization object. This includes
+    the regularization function itself and its gradient, hessian, and proximal operator.
+    """
+    name = '_base'
+
+    def f(self, beta):
+        """Regularization function."""
+        raise NotImplementedError
+
+    def gradient(self, beta):
+        """Gradient of regularization function."""
+        raise NotImplementedError
+
+    def hessian(self, beta):
+        """Hessian of regularization function."""
+        raise NotImplementedError
+
+    def proximal_operator(self, beta, t):
+        """Proximal operator for regularization function."""
+        raise NotImplementedError
+
+    def add_reg_f(self, f, lam):
+        """Add regularization function to other function."""
+        def wrapped(beta, *args):
+            return f(beta, *args) + lam * self.f(beta)
+        return wrapped
+
+    def add_reg_grad(self, grad, lam):
+        """Add regularization gradient to other gradient function."""
+        def wrapped(beta, *args):
+            return grad(beta, *args) + lam * self.gradient(beta)
+        return wrapped
+
+    def add_reg_hessian(self, hess, lam):
+        """Add regularization hessian to other hessian function."""
+        def wrapped(beta, *args):
+            return hess(beta, *args) + lam * self.hessian(beta)
+        return wrapped
+
+    @classmethod
+    def get(cls, obj):
+        """Return a regularizer instance, resolving strings via the subclasses' ``name``."""
+        if isinstance(obj, cls):
+            return obj
+        elif isinstance(obj, str):
+            return {o.name: o for o in cls.__subclasses__()}[obj]()
+        raise TypeError('Not a valid regularizer object.')
+
+
+class L2(Regularizer):
+    """L2 regularization."""
+    name = 'l2'
+
+    def f(self, beta):
+        return (beta**2).sum() / 2
+
+    def gradient(self, beta):
+        return beta
+
+    def hessian(self, beta):
+        return np.eye(len(beta))
+
+    def proximal_operator(self, beta, t):
+        return 1 / (1 + t) * beta
+
+
+class L1(Regularizer):
+    """L1 regularization."""
+    name = 'l1'
+
+    def f(self, beta):
+        return (np.abs(beta)).sum()
+
+    def gradient(self, beta):
+        if np.any(np.isclose(beta, 0)):
+            raise ValueError('l1 norm is not differentiable at 0!')
+        else:
+            return np.sign(beta)
+
+    def hessian(self, beta):
+        if np.any(np.isclose(beta, 0)):
+            raise ValueError('l1 norm is not twice differentiable at 0!')
+        return np.zeros((beta.shape[0], beta.shape[0]))
+
+    def proximal_operator(self, beta, t):
+        z = np.maximum(0, beta - t) - np.maximum(0, -beta - t)
+        return z
+
+
+class ElasticNet(Regularizer):
+    """Elastic net regularization."""
+    name = 'elastic_net'
+
+    def __init__(self, weight=0.5):
+        self.weight = weight
+        self.l1 = L1()
+        self.l2 = L2()
+
+    def _weighted(self, left, right):
+        return self.weight * left + (1 - self.weight) * right
+
+    def f(self, beta):
+        return self._weighted(self.l1.f(beta), self.l2.f(beta))
+
+    def gradient(self, beta):
+        return self._weighted(self.l1.gradient(beta), self.l2.gradient(beta))
+
+    def hessian(self, beta):
+        return self._weighted(self.l1.hessian(beta), self.l2.hessian(beta))
+
+    def proximal_operator(self, beta, t):
+        """See notebooks/ElasticNetProximalOperatorDerivation.ipynb for derivation."""
+        g = self.weight * t
+
+        @np.vectorize
+        def func(b):
+            if np.abs(b) <= g:
+                return 0
+            return (b - g * np.sign(b)) / (t - g + 1)
+        return func(beta)
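
The derivation notebook referenced in the docstring is not rendered on this page; as a quick sanity check (assuming the operator as reconstructed above), the elastic net proximal operator reduces to L1 soft-thresholding at ``weight=1`` and to the L2 shrinkage ``beta / (1 + t)`` at ``weight=0``:

    import numpy as np

    from dask_glm.regularizers import ElasticNet, L1, L2

    beta = np.array([-2.0, -0.1, 0.0, 0.3, 5.0])
    t = 1.0

    # weight=1: threshold g = t and denominator t - g + 1 = 1, i.e. soft-thresholding
    assert np.allclose(ElasticNet(weight=1).proximal_operator(beta, t),
                       L1().proximal_operator(beta, t))

    # weight=0: threshold g = 0 and denominator t + 1, i.e. plain L2 shrinkage
    assert np.allclose(ElasticNet(weight=0).proximal_operator(beta, t),
                       L2().proximal_operator(beta, t))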
2 changes: 1 addition & 1 deletion dask_glm/tests/test_admm.py
@@ -52,6 +52,6 @@ def test_admm_with_large_lamduh(N, p, nchunks):
    y = make_y(X, beta=np.array(beta), chunks=(N // nchunks,))

    X, y = persist(X, y)
-    z = admm(X, y, regularizer=L1, lamduh=1e4, rho=20, max_iter=500)
+    z = admm(X, y, regularizer=L1(), lamduh=1e4, rho=20, max_iter=500)

    assert np.allclose(z, np.zeros(p), atol=1e-4)
4 changes: 2 additions & 2 deletions dask_glm/tests/test_algos_families.py
@@ -9,7 +9,7 @@
from dask_glm.algorithms import (newton, bfgs, proximal_grad,
                                 gradient_descent, admm)
from dask_glm.families import Logistic, Normal, Poisson
-from dask_glm.regularizers import L1, L2
+from dask_glm.regularizers import Regularizer
from dask_glm.utils import sigmoid, make_y


@@ -89,7 +89,7 @@ def test_basic_unreg_descent(func, kwargs, N, nchunks, family):
@pytest.mark.parametrize('nchunks', [1, 10])
@pytest.mark.parametrize('family', [Logistic, Normal, Poisson])
@pytest.mark.parametrize('lam', [0.01, 1.2, 4.05])
-@pytest.mark.parametrize('reg', [L1, L2])
+@pytest.mark.parametrize('reg', [r() for r in Regularizer.__subclasses__()])
def test_basic_reg_descent(func, kwargs, N, nchunks, family, lam, reg):
    beta = np.random.normal(size=2)
    M = len(beta)
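
Since the parametrization enumerates ``Regularizer.__subclasses__()``, the new ``ElasticNet`` class is exercised automatically; as of this commit the list expands to (interactive check, not part of the diff):

    >>> from dask_glm.regularizers import Regularizer
    >>> sorted(r.name for r in Regularizer.__subclasses__())
    ['elastic_net', 'l1', 'l2']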
7 changes: 3 additions & 4 deletions dask_glm/tests/test_estimators.py
@@ -2,17 +2,16 @@

from dask_glm.estimators import LogisticRegression, LinearRegression, PoissonRegression
from dask_glm.datasets import make_classification, make_regression, make_poisson
-from dask_glm.algorithms import _solvers
-from dask_glm.regularizers import _regularizers
+from dask_glm.regularizers import Regularizer


-@pytest.fixture(params=_solvers.keys())
+@pytest.fixture(params=[r() for r in Regularizer.__subclasses__()])
def solver(request):
    """Parametrized fixture for all the solver names"""
    return request.param


-@pytest.fixture(params=_regularizers.keys())
+@pytest.fixture(params=[r() for r in Regularizer.__subclasses__()])
def regularizer(request):
    """Parametrized fixture for all the regularizer names"""
    return request.param
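
For reference, a hypothetical use of these fixtures' values, assuming the estimators accept ``solver`` and ``regularizer`` keywords as the fixture names suggest (the estimator signatures are not shown in this diff):

    from dask_glm.estimators import LogisticRegression
    from dask_glm.regularizers import ElasticNet

    # hypothetical keyword names, mirroring the fixture names above
    lr = LogisticRegression(solver='admm', regularizer=ElasticNet(weight=0.5))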