Skip to content

Commit

Permalink
Optimize Gradient Boost
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Jul 5, 2021
1 parent b0c17da commit 23d2abf
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions src/Regressors/GradientBoost.php
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class GradientBoost implements Estimator, Learner, RanksFeatures, Verbose, Persi
*
* @var int
*/
protected const MIN_SUBSAMPLE = 1;
protected const MIN_SUBSAMPLE = 2;

/**
* The regressor that will fix up the error residuals of the *weak* base learner.
Expand Down Expand Up @@ -392,12 +392,12 @@ public function train(Dataset $dataset) : void

$this->base->train($training);

$out = $prevOut = $this->base->predict($training);
$out = $this->base->predict($training);

$targets = $training->labels();

if (!$testing->empty()) {
$prevOutTest = $this->base->predict($testing);
$outTest = $this->base->predict($testing);
}

$p = max(self::MIN_SUBSAMPLE, (int) round($this->ratio * $m));
Expand Down Expand Up @@ -432,16 +432,16 @@ public function train(Dataset $dataset) : void

$predictions = $booster->predict($training);

$out = array_map([$this, 'updateOut'], $predictions, $prevOut);
$out = array_map([$this, 'updateOut'], $predictions, $out);

$this->losses[$epoch] = $loss;

$this->ensemble[] = $booster;

if (isset($prevOutTest)) {
if (isset($outTest)) {
$predictions = $booster->predict($testing);

$outTest = array_map([$this, 'updateOut'], $predictions, $prevOutTest);
$outTest = array_map([$this, 'updateOut'], $predictions, $outTest);

$score = $this->metric->score($outTest, $testing->labels());

Expand Down Expand Up @@ -470,18 +470,13 @@ public function train(Dataset $dataset) : void
if ($delta >= $this->window) {
break;
}

$prevOutTest = $outTest;
}

if (abs($prevLoss - $loss) < $this->minChange) {
break;
}

if ($epoch < $this->estimators) {
$prevOut = $out;
$prevLoss = $loss;
}
$prevLoss = $loss;
}

if ($this->scores and end($this->scores) <= $bestScore) {
Expand Down Expand Up @@ -518,10 +513,7 @@ public function predict(Dataset $dataset) : array
foreach ($this->ensemble as $estimator) {
$predictions = $estimator->predict($dataset);

/** @var int $j */
foreach ($predictions as $j => $prediction) {
$out[$j] += $this->rate * $prediction;
}
$out = array_map([$this, 'updateOut'], $predictions, $out);
}

return $out;
Expand All @@ -542,7 +534,9 @@ public function featureImportances() : array
$importances = array_fill(0, $this->featureCount, 0.0);

foreach ($this->ensemble as $tree) {
foreach ($tree->featureImportances() as $column => $importance) {
$importances = $tree->featureImportances();

foreach ($importances as $column => $importance) {
$importances[$column] += $importance;
}
}
Expand All @@ -560,12 +554,12 @@ public function featureImportances() : array
* Compute the output for an iteration.
*
* @param float $prediction
* @param float $prevOut
* @param float $out
* @return float
*/
protected function updateOut(float $prediction, float $prevOut) : float
protected function updateOut(float $prediction, float $out) : float
{
return $this->rate * $prediction + $prevOut;
return $this->rate * $prediction + $out;
}

/**
Expand Down

0 comments on commit 23d2abf

Please sign in to comment.