This repository has been archived by the owner on Jun 22, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
462 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
try: | ||
import catboost as ctb | ||
from catboost import CatBoostClassifier | ||
from steppy.base import BaseTransformer | ||
from steppy.utils import get_logger | ||
|
||
from toolkit.sklearn_transformers.models import MultilabelEstimators | ||
from toolkit.utils import SteppyToolkitError | ||
except ImportError as e: | ||
msg = 'SteppyToolkitError: you have missing modules. Install requirements specific to catboost_transformers.' \ | ||
'Use this file: toolkit/catboost_transformers/requirements.txt' | ||
raise SteppyToolkitError(msg) from e | ||
|
||
logger = get_logger() | ||
|
||
|
||
class CatboostClassifierMultilabel(MultilabelEstimators): | ||
@property | ||
def estimator(self): | ||
return CatBoostClassifier | ||
|
||
|
||
class CatBoost(BaseTransformer): | ||
def __init__(self, **kwargs): | ||
super().__init__() | ||
self.estimator = ctb.CatBoostClassifier(**kwargs) | ||
|
||
def fit(self, | ||
X, y, | ||
X_valid, y_valid, | ||
feature_names=None, | ||
categorical_features=None, | ||
**kwargs): | ||
|
||
logger.info('Catboost, train data shape {}'.format(X.shape)) | ||
logger.info('Catboost, validation data shape {}'.format(X_valid.shape)) | ||
logger.info('Catboost, train labels shape {}'.format(y.shape)) | ||
logger.info('Catboost, validation labels shape {}'.format(y_valid.shape)) | ||
|
||
categorical_indeces = self._get_categorical_indices(feature_names, categorical_features) | ||
self.estimator.fit(X, y, | ||
eval_set=(X_valid, y_valid), | ||
cat_features=categorical_indeces) | ||
return self | ||
|
||
def transform(self, X, **kwargs): | ||
prediction = self.estimator.predict_proba(X)[:, 1] | ||
return {'prediction': prediction} | ||
|
||
def load(self, filepath): | ||
self.estimator.load_model(filepath) | ||
return self | ||
|
||
def persist(self, filepath): | ||
self.estimator.save_model(filepath) | ||
|
||
def _get_categorical_indices(self, feature_names, categorical_features): | ||
if categorical_features: | ||
return [feature_names.index(feature) for feature in categorical_features] | ||
else: | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
catboost | ||
steppy |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
try: | ||
import lightgbm as lgb | ||
import numpy as np | ||
import pandas as pd | ||
from sklearn.externals import joblib | ||
from steppy.base import BaseTransformer | ||
from steppy.utils import get_logger | ||
|
||
from toolkit.utils import SteppyToolkitError | ||
except ImportError as e: | ||
msg = 'SteppyToolkitError: you have missing modules. Install requirements specific to lightgbm_transformers.' \ | ||
'Use this file: toolkit/lightgbm_transformers/requirements.txt' | ||
raise SteppyToolkitError(msg) from e | ||
|
||
logger = get_logger() | ||
|
||
|
||
class LightGBM(BaseTransformer): | ||
""" | ||
Accepts three dictionaries that reflects LightGBM API: | ||
- booster_parameters -> parameters of the Booster | ||
See: https://lightgbm.readthedocs.io/en/latest/Parameters.html | ||
- dataset_parameters -> parameters of the lightgbm.Dataset class | ||
See: https://lightgbm.readthedocs.io/en/latest/Python-API.html#data-structure-api | ||
- training_parameters -> parameters of the lightgbm.train function | ||
See: https://lightgbm.readthedocs.io/en/latest/Python-API.html#training-api | ||
""" | ||
def __init__(self, | ||
booster_parameters=None, | ||
dataset_parameters=None, | ||
training_parameters=None): | ||
super().__init__() | ||
logger.info('initializing LightGBM transformer') | ||
if booster_parameters is not None: | ||
isinstance(booster_parameters, dict), 'LightGBM transformer: booster_parameters must be dict, ' \ | ||
'got {} instead'.format(type(booster_parameters)) | ||
if dataset_parameters is not None: | ||
isinstance(dataset_parameters, dict), 'LightGBM transformer: dataset_parameters must be dict, ' \ | ||
'got {} instead'.format(type(dataset_parameters)) | ||
if training_parameters is not None: | ||
isinstance(training_parameters, dict), 'LightGBM transformer: training_parameters must be dict, ' \ | ||
'got {} instead'.format(type(training_parameters)) | ||
|
||
self.booster_parameters = booster_parameters or {} | ||
self.dataset_parameters = dataset_parameters or {} | ||
self.training_parameters = training_parameters or {} | ||
|
||
def fit(self, X, y, X_valid, y_valid): | ||
self._check_target_shape_and_type(y, 'y') | ||
self._check_target_shape_and_type(y_valid, 'y_valid') | ||
y = self._format_target(y) | ||
y_valid = self._format_target(y_valid) | ||
|
||
logger.info('LightGBM transformer, train data shape {}'.format(X.shape)) | ||
logger.info('LightGBM transformer, validation data shape {}'.format(X_valid.shape)) | ||
logger.info('LightGBM transformer, train labels shape {}'.format(y.shape)) | ||
logger.info('LightGBM transformer, validation labels shape {}'.format(y_valid.shape)) | ||
|
||
data_train = lgb.Dataset(data=X, | ||
label=y, | ||
**self.dataset_parameters) | ||
data_valid = lgb.Dataset(data=X_valid, | ||
label=y_valid, | ||
**self.dataset_parameters) | ||
self.estimator = lgb.train(params=self.booster_parameters, | ||
train_set=data_train, | ||
valid_sets=[data_train, data_valid], | ||
valid_names=['data_train', 'data_valid'], | ||
**self.training_parameters) | ||
return self | ||
|
||
def transform(self, X, y=None): | ||
prediction = self.estimator.predict(X) | ||
return {'prediction': prediction} | ||
|
||
def load(self, filepath): | ||
self.estimator = joblib.load(filepath) | ||
return self | ||
|
||
def persist(self, filepath): | ||
joblib.dump(self.estimator, filepath) | ||
|
||
def _check_target_shape_and_type(self, target, name): | ||
if not any([isinstance(target, obj_type) for obj_type in [pd.Series, np.ndarray, list]]): | ||
msg = '"target" must be "numpy.ndarray" or "Pandas.Series" or "list", got {} instead.'.format(type(target)) | ||
raise SteppyToolkitError(msg) | ||
if not isinstance(target, list): | ||
assert len(target.shape) == 1, '"{}" must be 1-D. It is {}-D instead.'.format(name, len(target.shape)) | ||
|
||
def _format_target(self, target): | ||
if isinstance(target, pd.Series): | ||
return target.values | ||
elif isinstance(target, np.ndarray): | ||
return target | ||
elif isinstance(target, list): | ||
return np.array(target) | ||
else: | ||
raise TypeError( | ||
'"target" must be "numpy.ndarray" or "Pandas.Series" or "list", got {} instead.'.format( | ||
type(target))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
attrdict | ||
lightgbm | ||
numpy | ||
pandas | ||
sklearn | ||
steppy |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.