Skip to content

Commit

Permalink
chore: use NumPower (#315)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcharytoniuk authored Dec 26, 2024
1 parent 646b1a2 commit ca59648
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions src/Regressors/Ridge.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;
use NDArray as nd;

use function is_null;

Expand Down Expand Up @@ -60,6 +61,8 @@ class Ridge implements Estimator, Learner, RanksFeatures, Persistable
*/
protected ?Vector $coefficients = null;

protected ?nd $coefficientsNd = null;

/**
* @param float $l2Penalty
* @throws InvalidArgumentException
Expand Down Expand Up @@ -161,7 +164,7 @@ public function train(Dataset $dataset) : void
$biases = Matrix::ones($dataset->numSamples(), 1);

$x = Matrix::build($dataset->samples())->augmentLeft($biases);
$y = Vector::build($dataset->labels());
$y = nd::array($dataset->labels());

/** @var int<0,max> $nHat */
$nHat = $x->n() - 1;
Expand All @@ -170,15 +173,18 @@ public function train(Dataset $dataset) : void

array_unshift($penalties, 0.0);

$penalties = Matrix::diagonal($penalties);
$penalties = nd::array(Matrix::diagonal($penalties)->asArray());

$xNp = nd::array($x->asArray());
$xT = nd::transpose($xNp);

$xT = $x->transpose();
$xMul = nd::matmul($xT, $xNp);
$xMulAdd = nd::add($xMul, $penalties);
$xMulAddInv = nd::inv($xMulAdd);
$xtDotY = nd::dot($xT, $y);

$coefficients = $xT->matmul($x)
->add($penalties)
->inverse()
->dot($xT->dot($y))
->asArray();
$this->coefficientsNd = nd::dot($xMulAddInv, $xtDotY);
$coefficients = $this->coefficientsNd->toArray();

$this->bias = (float) array_shift($coefficients);
$this->coefficients = Vector::quick($coefficients);
Expand All @@ -199,10 +205,10 @@ public function predict(Dataset $dataset) : array

DatasetHasDimensionality::with($dataset, count($this->coefficients))->check();

return Matrix::build($dataset->samples())
->dot($this->coefficients)
->add($this->bias)
->asArray();
$datasetNd = nd::array($dataset->samples());
$datasetDotCoefficients = nd::dot($datasetNd, $this->coefficientsNd);

return nd::add($datasetDotCoefficients, $this->bias)->toArray();
}

/**
Expand Down

0 comments on commit ca59648

Please sign in to comment.