From d43a5f2cd714277c10f74096cc5776ce4c828354 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Thu, 23 May 2024 12:44:44 -0500 Subject: [PATCH] 2.5 (#336) * Initial commit (#299) * Vantage Tree (#300) * Initial commit * Better testing * Improve the docs * Rename benchmark * Explicitly import max() function * Fix coding style * Wrapper interface (#314) * Add Wrapper interface for models wrappers * Add WrapperAware trait * Fix PhpDoc * Revert "Add WrapperAware trait" This reverts commit 241abc4317eec701211b7a88a17a1b610c366dfe. * Rename Wrapper interface to EstimatorWrapper * PHP CS fix * Swoole Backend (#312) * add Swoole backend * phpstan: ignore swoole * feat: swoole process scheduler * fix(swoole): redo tasks when hash collision happens * chore(swoole): make sure coroutines are at the root of the scheduler * chore(swoole): set affinity / bind worker to a specific CPU core * chore(swoole): use igbinary if available * fix: remove comment * fix(swoole): worker cpu affinity * fix(swoole): cpu num * feat: scheduler improvements * style * chore(swoole): remove unnecessary atomics * chore(swoole): php backwards compatibility * fix: phpstan, socket message size * fix: uncomment test * style: composer fix * Plus plus check (#317) * Initial commit * Allow deltas in units tests * Swoole docs (#326) * add Swoole backend * phpstan: ignore swoole * feat: swoole process scheduler * fix(swoole): redo tasks when hash collision happens * chore(swoole): make sure coroutines are at the root of the scheduler * chore(swoole): set affinity / bind worker to a specific CPU core * chore(swoole): use igbinary if available * fix: remove comment * fix(swoole): worker cpu affinity * fix(swoole): cpu num * feat: scheduler improvements * style * chore(swoole): remove unnecessary atomics * chore(swoole): php backwards compatibility * fix: phpstan, socket message size * fix: uncomment test * style: composer fix * docs: Swoole backend * Fix coding style and composer.lock * fix(swoole): 
setAffinity does not exist on some versions of Swoole (#327) * Back out Swoole Backend code * Bump version --------- Co-authored-by: Ronan Giron Co-authored-by: Mateusz Charytoniuk --- CHANGELOG.md | 7 + benchmarks/Graph/Trees/VantageTreeBench.php | 49 +++ composer.json | 2 +- docs/datasets/generators/blob.md | 12 +- docs/graph/trees/vantage-tree.md | 28 ++ mkdocs.yml | 1 + phpunit.xml | 15 +- src/Clusterers/Seeders/PlusPlus.php | 7 + src/Datasets/Generators/Blob.php | 42 +++ src/EstimatorWrapper.php | 20 + src/Graph/Nodes/VantagePoint.php | 160 ++++++++ src/Graph/Trees/VantageTree.php | 346 ++++++++++++++++++ src/GridSearch.php | 2 +- src/PersistentModel.php | 2 +- src/Pipeline.php | 2 +- src/constants.php | 2 +- tests/BootstrapAggregatorTest.php | 3 - tests/Classifiers/OneVsRestTest.php | 3 - tests/Classifiers/RadiusNeighborsTest.php | 6 +- tests/Classifiers/RandomForestTest.php | 3 - tests/CommitteeMachineTest.php | 3 - tests/CrossValidation/KFoldTest.php | 3 - tests/CrossValidation/LeavePOutTest.php | 3 - tests/CrossValidation/MonteCarloTest.php | 3 - tests/Datasets/Generators/BlobTest.php | 21 ++ .../{SQTableTest.php => SQLTableTest.php} | 0 tests/Graph/Nodes/VantagePointTest.php | 100 +++++ tests/Graph/Trees/VantageTreeTest.php | 108 ++++++ tests/GridSearchTest.php | 2 - tests/Transformers/ImageRotatorTest.php | 2 +- tests/Transformers/MaxAbsoluteScalerTest.php | 6 +- tests/Transformers/RobustStandardizerTest.php | 2 +- 32 files changed, 926 insertions(+), 39 deletions(-) create mode 100644 benchmarks/Graph/Trees/VantageTreeBench.php create mode 100644 docs/graph/trees/vantage-tree.md create mode 100644 src/EstimatorWrapper.php create mode 100644 src/Graph/Nodes/VantagePoint.php create mode 100644 src/Graph/Trees/VantageTree.php rename tests/Extractors/{SQTableTest.php => SQLTableTest.php} (100%) create mode 100644 tests/Graph/Nodes/VantagePointTest.php create mode 100644 tests/Graph/Trees/VantageTreeTest.php diff --git a/CHANGELOG.md b/CHANGELOG.md index 
3acb07ca6..6faa04d1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +- 2.5.0 + - Added Vantage Point Spatial tree + - Blob Generator can now `simulate()` a Dataset object + - Added Wrapper interface + - Plus Plus added check for min number of sample seeds + - LOF prevent div by 0 local reachability density + - 2.4.1 - Sentence Tokenizer fix Arabic and Farsi language support - Optimize online variance updating diff --git a/benchmarks/Graph/Trees/VantageTreeBench.php b/benchmarks/Graph/Trees/VantageTreeBench.php new file mode 100644 index 000000000..b2e878256 --- /dev/null +++ b/benchmarks/Graph/Trees/VantageTreeBench.php @@ -0,0 +1,49 @@ + new Blob([5.0, 3.42, 1.46, 0.24], [0.35, 0.38, 0.17, 0.1]), + 'Iris-versicolor' => new Blob([5.94, 2.77, 4.26, 1.33], [0.51, 0.31, 0.47, 0.2]), + 'Iris-virginica' => new Blob([6.59, 2.97, 5.55, 2.03], [0.63, 0.32, 0.55, 0.27]), + ]); + + $this->dataset = $generator->generate(self::DATASET_SIZE); + + $this->tree = new VantageTree(30); + } + + /** + * @Subject + * @Iterations(3) + * @OutputTimeUnit("seconds", precision=3) + */ + public function grow() : void + { + $this->tree->grow($this->dataset); + } +} diff --git a/composer.json b/composer.json index 8cb313510..b17ac2739 100644 --- a/composer.json +++ b/composer.json @@ -79,7 +79,7 @@ "@test", "@check" ], - "analyze": "phpstan analyse -c phpstan.neon", + "analyze": "phpstan analyse -c phpstan.neon --memory-limit 1G", "benchmark": "phpbench run --report=aggregate", "check": [ "@putenv PHP_CS_FIXER_IGNORE_ENV=1", diff --git a/docs/datasets/generators/blob.md b/docs/datasets/generators/blob.md index a0a8cbe4e..e99bf65ab 100644 --- a/docs/datasets/generators/blob.md +++ b/docs/datasets/generators/blob.md @@ -17,8 +17,16 @@ A normally distributed (Gaussian) n-dimensional blob of samples centered at a gi ```php use Rubix\ML\Datasets\Generators\Blob; -$generator = new Blob([-1.2, -5., 2.6, 0.8, 10.], 0.25); +$generator = new Blob([-1.2, -5.0, 2.6, 0.8, 10.0], 0.25); ``` 
## Additional Methods -This generator does not have any additional methods. +Fit a Blob generator to the samples in a dataset. +```php +public static simulate(Dataset $dataset) : self +``` + +Return the center coordinates of the Blob. +```php +public center() : array +``` diff --git a/docs/graph/trees/vantage-tree.md b/docs/graph/trees/vantage-tree.md new file mode 100644 index 000000000..de6aa5ef4 --- /dev/null +++ b/docs/graph/trees/vantage-tree.md @@ -0,0 +1,28 @@ +[source] + +# Vantage Tree +A Vantage Point Tree is a binary spatial tree that divides samples by their distance from the center of a cluster called the *vantage point*. Samples that are closer to the vantage point will be put into one branch of the tree while samples that are farther away will be put into the other branch. + +**Interfaces:** Binary Tree, Spatial + +**Data Type Compatibility:** Depends on distance kernel + +## Parameters +| # | Param | Default | Type | Description | +|---|---|---|---|---| +| 1 | max leaf size | 30 | int | The maximum number of samples that each leaf node can contain. | +| 2 | kernel | Euclidean | Distance | The distance kernel used to compute the distance between sample points. | + +## Example +```php +use Rubix\ML\Graph\Trees\VantageTree; +use Rubix\ML\Kernels\Distance\Euclidean; + +$tree = new VantageTree(30, new Euclidean()); +``` + +## Additional Methods +This tree does not have any additional methods. + +### References +>- P. N. Yianilos. (1993). Data Structures and Algorithms for Nearest Neighbor Search in General Metric Spaces. 
diff --git a/mkdocs.yml b/mkdocs.yml index b1752ad3e..015db2181 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -201,6 +201,7 @@ nav: - Trees: - Ball Tree: graph/trees/ball-tree.md - K-d Tree: graph/trees/k-d-tree.md + - Vantage Tree: graph/trees/vantage-tree.md - Kernels: - Distance: - Canberra: kernels/distance/canberra.md diff --git a/phpunit.xml b/phpunit.xml index 33e100832..f2656a836 100644 --- a/phpunit.xml +++ b/phpunit.xml @@ -1,5 +1,18 @@ - + src diff --git a/src/Clusterers/Seeders/PlusPlus.php b/src/Clusterers/Seeders/PlusPlus.php index f69d151a2..ddadd1877 100644 --- a/src/Clusterers/Seeders/PlusPlus.php +++ b/src/Clusterers/Seeders/PlusPlus.php @@ -6,6 +6,7 @@ use Rubix\ML\Kernels\Distance\Distance; use Rubix\ML\Kernels\Distance\Euclidean; use Rubix\ML\Specifications\DatasetIsNotEmpty; +use Rubix\ML\Exceptions\RuntimeException; use function count; @@ -49,12 +50,18 @@ public function __construct(?Distance $kernel = null) * * @param Dataset $dataset * @param int $k + * @throws RuntimeException * @return list> */ public function seed(Dataset $dataset, int $k) : array { DatasetIsNotEmpty::with($dataset)->check(); + if ($k > $dataset->numSamples()) { + throw new RuntimeException("Cannot seed $k clusters with only " + . $dataset->numSamples() . ' samples.'); + } + $centroids = $dataset->randomSubset(1)->samples(); while (count($centroids) < $k) { diff --git a/src/Datasets/Generators/Blob.php b/src/Datasets/Generators/Blob.php index e9f03ba41..f79778173 100644 --- a/src/Datasets/Generators/Blob.php +++ b/src/Datasets/Generators/Blob.php @@ -4,10 +4,14 @@ use Tensor\Matrix; use Tensor\Vector; +use Rubix\ML\DataType; +use Rubix\ML\Helpers\Stats; +use Rubix\ML\Datasets\Dataset; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Exceptions\InvalidArgumentException; use function count; +use function sqrt; /** * Blob @@ -37,6 +41,34 @@ class Blob implements Generator */ protected $stdDev; + /** + * Fit a Blob generator to the samples in a dataset. 
+ * + * @param Dataset $dataset + * @throws InvalidArgumentException + * @return self + */ + public static function simulate(Dataset $dataset) : self + { + $features = $dataset->featuresByType(DataType::continuous()); + + if (count($features) !== $dataset->numFeatures()) { + throw new InvalidArgumentException('Dataset must only contain' + . ' continuous features.'); + } + + $means = $stdDevs = []; + + foreach ($features as $values) { + [$mean, $variance] = Stats::meanVar($values); + + $means[] = $mean; + $stdDevs[] = sqrt($variance); + } + + return new self($means, $stdDevs); + } + /** * @param (int|float)[] $center * @param int|float|(int|float)[] $stdDev @@ -74,6 +106,16 @@ public function __construct(array $center = [0, 0], $stdDev = 1.0) $this->stdDev = $stdDev; } + /** + * Return the center coordinates of the Blob. + * + * @return list + */ + public function center() : array + { + return $this->center->asArray(); + } + /** * Return the dimensionality of the data this generates. * diff --git a/src/EstimatorWrapper.php b/src/EstimatorWrapper.php new file mode 100644 index 000000000..aafb3ac8e --- /dev/null +++ b/src/EstimatorWrapper.php @@ -0,0 +1,20 @@ + + */ + protected $center; + + /** + * The radius of the centroid. + * + * @var float + */ + protected $radius; + + /** + * The left and right splits of the training data. + * + * @var array{Labeled,Labeled}|null + */ + protected $subsets; + + /** + * Factory method to build a hypersphere by splitting the dataset into left and right clusters. 
+ * + * @param Labeled $dataset + * @param Distance $kernel + * @return self + */ + public static function split(Labeled $dataset, Distance $kernel) : self + { + $center = []; + + foreach ($dataset->features() as $column => $values) { + if ($dataset->featureType($column)->isContinuous()) { + $center[] = Stats::mean($values); + } else { + $center[] = argmax(array_count_values($values)); + } + } + + $distances = []; + + foreach ($dataset->samples() as $sample) { + $distances[] = $kernel->compute($sample, $center); + } + + $threshold = Stats::median($distances); + + $samples = $dataset->samples(); + $labels = $dataset->labels(); + + $leftSamples = $leftLabels = $rightSamples = $rightLabels = []; + + foreach ($distances as $i => $distance) { + if ($distance <= $threshold) { + $leftSamples[] = $samples[$i]; + $leftLabels[] = $labels[$i]; + } else { + $rightSamples[] = $samples[$i]; + $rightLabels[] = $labels[$i]; + } + } + + $radius = max($distances) ?: 0.0; + + return new self($center, $radius, [ + Labeled::quick($leftSamples, $leftLabels), + Labeled::quick($rightSamples, $rightLabels), + ]); + } + + /** + * @param list $center + * @param float $radius + * @param array{Labeled,Labeled} $subsets + */ + public function __construct(array $center, float $radius, array $subsets) + { + $this->center = $center; + $this->radius = $radius; + $this->subsets = $subsets; + } + + /** + * Return the center vector. + * + * @return list + */ + public function center() : array + { + return $this->center; + } + + /** + * Return the radius of the centroid. + * + * @return float + */ + public function radius() : float + { + return $this->radius; + } + + /** + * Return the left and right subsets of the training data. 
+ * + * @throws RuntimeException + * @return array{\Rubix\ML\Datasets\Labeled,\Rubix\ML\Datasets\Labeled} + */ + public function subsets() : array + { + if (!isset($this->subsets)) { + throw new RuntimeException('Subsets property does not exist.'); + } + + return $this->subsets; + } + + /** + * Does the hypersphere reduce to a single point? + * + * @return bool + */ + public function isPoint() : bool + { + return $this->radius === 0.0; + } + + /** + * Remove the left and right splits of the training data. + */ + public function cleanup() : void + { + unset($this->subsets); + } +} diff --git a/src/Graph/Trees/VantageTree.php b/src/Graph/Trees/VantageTree.php new file mode 100644 index 000000000..a1c0d648b --- /dev/null +++ b/src/Graph/Trees/VantageTree.php @@ -0,0 +1,346 @@ +maxLeafSize = $maxLeafSize; + $this->kernel = $kernel ?? new Euclidean(); + } + + /** + * Return the height of the tree i.e. the number of levels. + * + * @return int + */ + public function height() : int + { + return $this->root ? $this->root->height() : 0; + } + + /** + * Return the balance factor of the tree. A balanced tree will have + * a factor of 0 whereas an imbalanced tree will either be positive + * or negative indicating the direction and degree of the imbalance. + * + * @return int + */ + public function balance() : int + { + return $this->root ? $this->root->balance() : 0; + } + + /** + * Is the tree bare? + * + * @return bool + */ + public function bare() : bool + { + return !$this->root; + } + + /** + * Return the distance kernel used to compute distances. + * + * @return Distance + */ + public function kernel() : Distance + { + return $this->kernel; + } + + /** + * Insert a root node and recursively split the dataset until a terminating + * condition is met. 
+     *
+     * @internal
+     *
+     * @param Labeled $dataset
+     * @throws InvalidArgumentException
+     */
+    public function grow(Labeled $dataset) : void
+    {
+        // NOTE(review): this guard looks unreachable given the Labeled type hint - confirm before removing.
+        if (!$dataset instanceof Labeled) {
+            throw new InvalidArgumentException('Tree requires a labeled dataset.');
+        }
+
+        $this->root = VantagePoint::split($dataset, $this->kernel);
+
+        $stack = [$this->root];
+
+        while ($current = array_pop($stack)) {
+            [$left, $right] = $current->subsets();
+
+            // The subsets are only needed while growing, release them as soon as they are read.
+            $current->cleanup();
+
+            if ($left->numSamples() > $this->maxLeafSize) {
+                $node = VantagePoint::split($left, $this->kernel);
+
+                // A zero radius means every sample sits at zero distance from the split
+                // center, so splitting again cannot make progress - terminate with a leaf.
+                if ($node->isPoint()) {
+                    $current->attachLeft(Clique::terminate($left, $this->kernel));
+                } else {
+                    $current->attachLeft($node);
+
+                    $stack[] = $node;
+                }
+            } elseif (!$left->empty()) {
+                $current->attachLeft(Clique::terminate($left, $this->kernel));
+            }
+
+            if ($right->numSamples() > $this->maxLeafSize) {
+                $node = VantagePoint::split($right, $this->kernel);
+
+                $current->attachRight($node);
+
+                $stack[] = $node;
+            } elseif (!$right->empty()) {
+                $current->attachRight(Clique::terminate($right, $this->kernel));
+            }
+        }
+    }
+
+    /**
+     * Run a k nearest neighbors search and return the samples, labels, and
+     * distances in a 3-tuple.
+     *
+     * @param (string|int|float)[] $sample
+     * @param int $k
+     * @throws InvalidArgumentException
+     * @return array<array<mixed>>
+     */
+    public function nearest(array $sample, int $k = 1) : array
+    {
+        if ($k < 1) {
+            throw new InvalidArgumentException('K must be'
+                . " greater than 0, $k given.");
+        }
+
+        $visited = new SplObjectStorage();
+
+        $stack = $this->path($sample);
+
+        $samples = $labels = $distances = [];
+
+        while ($current = array_pop($stack)) {
+            if ($current instanceof VantagePoint) {
+                // Distance to the current k-th best neighbor, or infinity until k have been found.
+                $radius = $distances[$k - 1] ?? INF;
+
+                foreach ($current->children() as $child) {
+                    if (!$visited->contains($child)) {
+                        if ($child instanceof Hypersphere) {
+                            $distance = $this->kernel->compute($sample, $child->center());
+
+                            // Only descend into children whose surface is closer than the k-th best so far.
+                            if ($distance - $child->radius() < $radius) {
+                                $stack[] = $child;
+
+                                continue;
+                            }
+                        }
+
+                        $visited->attach($child);
+                    }
+                }
+
+                $visited->attach($current);
+
+                continue;
+            }
+
+            if ($current instanceof Clique) {
+                $dataset = $current->dataset();
+
+                foreach ($dataset->samples() as $neighbor) {
+                    $distances[] = $this->kernel->compute($sample, $neighbor);
+                }
+
+                $samples = array_merge($samples, $dataset->samples());
+                $labels = array_merge($labels, $dataset->labels());
+
+                // Sort all three arrays in tandem by ascending distance.
+                array_multisort($distances, $samples, $labels);
+
+                if (count($samples) > $k) {
+                    $samples = array_slice($samples, 0, $k);
+                    $labels = array_slice($labels, 0, $k);
+                    $distances = array_slice($distances, 0, $k);
+                }
+
+                $visited->attach($current);
+            }
+        }
+
+        return [$samples, $labels, $distances];
+    }
+
+    /**
+     * Return all samples, labels, and distances within a given radius of a sample.
+     * The radius is inclusive i.e. samples lying exactly on the boundary are returned.
+     *
+     * @param (string|int|float)[] $sample
+     * @param float $radius
+     * @throws InvalidArgumentException
+     * @return array<array<mixed>>
+     */
+    public function range(array $sample, float $radius) : array
+    {
+        if ($radius <= 0.0) {
+            throw new InvalidArgumentException('Radius must be'
+                . " greater than 0, $radius given.");
+        }
+
+        $samples = $labels = $distances = [];
+
+        $stack = [$this->root];
+
+        while ($current = array_pop($stack)) {
+            if ($current instanceof VantagePoint) {
+                foreach ($current->children() as $child) {
+                    if ($child instanceof Hypersphere) {
+                        $distance = $this->kernel->compute($sample, $child->center());
+
+                        // Inclusive comparison so that pruning agrees with the inclusive
+                        // membership test applied to the samples of a leaf node below.
+                        if ($distance - $child->radius() <= $radius) {
+                            $stack[] = $child;
+                        }
+                    }
+                }
+
+                continue;
+            }
+
+            if ($current instanceof Clique) {
+                $dataset = $current->dataset();
+
+                foreach ($dataset->samples() as $i => $neighbor) {
+                    $distance = $this->kernel->compute($sample, $neighbor);
+
+                    if ($distance <= $radius) {
+                        $samples[] = $neighbor;
+                        $labels[] = $dataset->label($i);
+                        $distances[] = $distance;
+                    }
+                }
+            }
+        }
+
+        return [$samples, $labels, $distances];
+    }
+
+    /**
+     * Destroy the tree.
+     */
+    public function destroy() : void
+    {
+        // Assign null rather than unset() so that later reads of the root (height, bare,
+        // range, path) see an empty tree instead of touching an undefined property.
+        $this->root = null;
+    }
+
+    /**
+     * Return the path of a sample taken from the root node to a leaf node
+     * in an array. When a vantage point has no left hypersphere the walk
+     * descends into the right child so that it always terminates.
+     *
+     * @param (string|int|float)[] $sample
+     * @return mixed[]
+     */
+    protected function path(array $sample) : array
+    {
+        $current = $this->root;
+
+        $path = [];
+
+        while ($current) {
+            $path[] = $current;
+
+            if ($current instanceof VantagePoint) {
+                $left = $current->left();
+                $right = $current->right();
+
+                if ($left instanceof Hypersphere) {
+                    $distance = $this->kernel->compute($sample, $left->center());
+
+                    if ($distance <= $left->radius()) {
+                        $current = $left;
+                    } else {
+                        $current = $right;
+                    }
+                } else {
+                    // No left child to test against - move on to the right child (or stop
+                    // if there is none) instead of revisiting the same node forever.
+                    $current = $right;
+                }
+
+                continue;
+            }
+
+            break;
+        }
+
+        return $path;
+    }
+
+    /**
+     * Return the string representation of the object.
+ * + * @return string + */ + public function __toString() : string + { + return "Vantage Tree (max_leaf_size: {$this->maxLeafSize}, kernel: {$this->kernel})"; + } +} diff --git a/src/GridSearch.php b/src/GridSearch.php index 0eaafdd18..67140769b 100644 --- a/src/GridSearch.php +++ b/src/GridSearch.php @@ -39,7 +39,7 @@ * @package Rubix/ML * @author Andrew DalPino */ -class GridSearch implements Estimator, Learner, Parallel, Verbose, Persistable +class GridSearch implements EstimatorWrapper, Learner, Parallel, Verbose, Persistable { use AutotrackRevisions, Multiprocessing, LoggerAware; diff --git a/src/PersistentModel.php b/src/PersistentModel.php index 40183dbf6..b8b60269f 100644 --- a/src/PersistentModel.php +++ b/src/PersistentModel.php @@ -21,7 +21,7 @@ * @package Rubix/ML * @author Andrew DalPino */ -class PersistentModel implements Estimator, Learner, Probabilistic, Scoring +class PersistentModel implements EstimatorWrapper, Learner, Probabilistic, Scoring { /** * The persistable base learner. diff --git a/src/Pipeline.php b/src/Pipeline.php index 62abc5c9b..53b79c8ec 100644 --- a/src/Pipeline.php +++ b/src/Pipeline.php @@ -25,7 +25,7 @@ * @package Rubix/ML * @author Andrew DalPino */ -class Pipeline implements Online, Probabilistic, Scoring, Persistable +class Pipeline implements Online, Probabilistic, Scoring, Persistable, EstimatorWrapper { use AutotrackRevisions; diff --git a/src/constants.php b/src/constants.php index 7f4474967..5dda0d669 100644 --- a/src/constants.php +++ b/src/constants.php @@ -9,7 +9,7 @@ * * @var literal-string */ - const VERSION = '2.4'; + const VERSION = '2.5'; /** * A very small positive number, sometimes used in substitution of 0. 
diff --git a/tests/BootstrapAggregatorTest.php b/tests/BootstrapAggregatorTest.php index c8062dc5f..59fae0487 100644 --- a/tests/BootstrapAggregatorTest.php +++ b/tests/BootstrapAggregatorTest.php @@ -7,7 +7,6 @@ use Rubix\ML\Estimator; use Rubix\ML\Persistable; use Rubix\ML\EstimatorType; -use Rubix\ML\Backends\Serial; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\BootstrapAggregator; use Rubix\ML\Regressors\RegressionTree; @@ -115,8 +114,6 @@ public function params() : void */ public function trainPredict() : void { - $this->estimator->setBackend(new Serial()); - $training = $this->generator->generate(self::TRAIN_SIZE); $testing = $this->generator->generate(self::TEST_SIZE); diff --git a/tests/Classifiers/OneVsRestTest.php b/tests/Classifiers/OneVsRestTest.php index 26bcc2d24..f74d8de5e 100644 --- a/tests/Classifiers/OneVsRestTest.php +++ b/tests/Classifiers/OneVsRestTest.php @@ -9,7 +9,6 @@ use Rubix\ML\Persistable; use Rubix\ML\Probabilistic; use Rubix\ML\EstimatorType; -use Rubix\ML\Backends\Serial; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Classifiers\OneVsRest; use Rubix\ML\Classifiers\GaussianNB; @@ -81,8 +80,6 @@ protected function setUp() : void $this->estimator = new OneVsRest(new GaussianNB()); - $this->estimator->setBackend(new Serial()); - $this->metric = new FBeta(); srand(self::RANDOM_SEED); diff --git a/tests/Classifiers/RadiusNeighborsTest.php b/tests/Classifiers/RadiusNeighborsTest.php index 0b4f774cb..5a2878bcb 100644 --- a/tests/Classifiers/RadiusNeighborsTest.php +++ b/tests/Classifiers/RadiusNeighborsTest.php @@ -10,7 +10,7 @@ use Rubix\ML\EstimatorType; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; -use Rubix\ML\Graph\Trees\BallTree; +use Rubix\ML\Graph\Trees\VantageTree; use Rubix\ML\Datasets\Generators\Blob; use Rubix\ML\Classifiers\RadiusNeighbors; use Rubix\ML\Datasets\Generators\Agglomerate; @@ -79,7 +79,7 @@ protected function setUp() : void 'blue' => new Blob([0, 32, 255], 30.0), ], [0.5, 0.2, 0.3]); - 
$this->estimator = new RadiusNeighbors(60.0, true, '?', new BallTree()); + $this->estimator = new RadiusNeighbors(60.0, true, '?', new VantageTree()); $this->metric = new FBeta(); @@ -142,7 +142,7 @@ public function params() : void 'radius' => 60.0, 'weighted' => true, 'outlier class' => '?', - 'tree' => new BallTree(), + 'tree' => new VantageTree(), ]; $this->assertEquals($expected, $this->estimator->params()); diff --git a/tests/Classifiers/RandomForestTest.php b/tests/Classifiers/RandomForestTest.php index 9572a9cf2..47d179b0c 100644 --- a/tests/Classifiers/RandomForestTest.php +++ b/tests/Classifiers/RandomForestTest.php @@ -9,7 +9,6 @@ use Rubix\ML\Probabilistic; use Rubix\ML\RanksFeatures; use Rubix\ML\EstimatorType; -use Rubix\ML\Backends\Serial; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Classifiers\RandomForest; use Rubix\ML\Datasets\Generators\Blob; @@ -82,8 +81,6 @@ protected function setUp() : void $this->estimator = new RandomForest(new ClassificationTree(3), 50, 0.2, true); - $this->estimator->setBackend(new Serial()); - $this->metric = new FBeta(); srand(self::RANDOM_SEED); diff --git a/tests/CommitteeMachineTest.php b/tests/CommitteeMachineTest.php index f9af35bfb..722645147 100644 --- a/tests/CommitteeMachineTest.php +++ b/tests/CommitteeMachineTest.php @@ -8,7 +8,6 @@ use Rubix\ML\Estimator; use Rubix\ML\Persistable; use Rubix\ML\EstimatorType; -use Rubix\ML\Backends\Serial; use Rubix\ML\CommitteeMachine; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Classifiers\GaussianNB; @@ -135,8 +134,6 @@ public function params() : void */ public function trainPredict() : void { - $this->estimator->setBackend(new Serial()); - $training = $this->generator->generate(self::TRAIN_SIZE); $testing = $this->generator->generate(self::TEST_SIZE); diff --git a/tests/CrossValidation/KFoldTest.php b/tests/CrossValidation/KFoldTest.php index 805236cbe..f1cc000b9 100644 --- a/tests/CrossValidation/KFoldTest.php +++ b/tests/CrossValidation/KFoldTest.php @@ -3,7 +3,6 @@ 
namespace Rubix\ML\Tests\CrossValidation; use Rubix\ML\Parallel; -use Rubix\ML\Backends\Serial; use Rubix\ML\CrossValidation\KFold; use Rubix\ML\Datasets\Generators\Blob; use Rubix\ML\CrossValidation\Validator; @@ -54,8 +53,6 @@ protected function setUp() : void $this->validator = new KFold(10); - $this->validator->setBackend(new Serial()); - $this->metric = new Accuracy(); } diff --git a/tests/CrossValidation/LeavePOutTest.php b/tests/CrossValidation/LeavePOutTest.php index dbb54bf57..06ca06fa8 100644 --- a/tests/CrossValidation/LeavePOutTest.php +++ b/tests/CrossValidation/LeavePOutTest.php @@ -3,7 +3,6 @@ namespace Rubix\ML\Tests\CrossValidation; use Rubix\ML\Parallel; -use Rubix\ML\Backends\Serial; use Rubix\ML\Datasets\Generators\Blob; use Rubix\ML\CrossValidation\LeavePOut; use Rubix\ML\CrossValidation\Validator; @@ -54,8 +53,6 @@ protected function setUp() : void $this->validator = new LeavePOut(10); - $this->validator->setBackend(new Serial()); - $this->metric = new Accuracy(); } diff --git a/tests/CrossValidation/MonteCarloTest.php b/tests/CrossValidation/MonteCarloTest.php index ca8967bfb..88f03ba5f 100644 --- a/tests/CrossValidation/MonteCarloTest.php +++ b/tests/CrossValidation/MonteCarloTest.php @@ -3,7 +3,6 @@ namespace Rubix\ML\Tests\CrossValidation; use Rubix\ML\Parallel; -use Rubix\ML\Backends\Serial; use Rubix\ML\Datasets\Generators\Blob; use Rubix\ML\CrossValidation\Validator; use Rubix\ML\CrossValidation\MonteCarlo; @@ -54,8 +53,6 @@ protected function setUp() : void $this->validator = new MonteCarlo(3, 0.2); - $this->validator->setBackend(new Serial()); - $this->metric = new Accuracy(); } diff --git a/tests/Datasets/Generators/BlobTest.php b/tests/Datasets/Generators/BlobTest.php index 4c7342051..130d54143 100644 --- a/tests/Datasets/Generators/BlobTest.php +++ b/tests/Datasets/Generators/BlobTest.php @@ -29,6 +29,19 @@ protected function setUp() : void $this->generator = new Blob([0, 0, 0], 1.0); } + /** + * @test + */ + public function 
simulate() : void + { + $dataset = $this->generator->generate(100); + + $generator = Blob::simulate($dataset); + + $this->assertInstanceOf(Blob::class, $generator); + $this->assertInstanceOf(Generator::class, $generator); + } + /** * @test */ @@ -38,6 +51,14 @@ public function build() : void $this->assertInstanceOf(Generator::class, $this->generator); } + /** + * @test + */ + public function center() : void + { + $this->assertEquals([0, 0, 0], $this->generator->center()); + } + /** * @test */ diff --git a/tests/Extractors/SQTableTest.php b/tests/Extractors/SQLTableTest.php similarity index 100% rename from tests/Extractors/SQTableTest.php rename to tests/Extractors/SQLTableTest.php diff --git a/tests/Graph/Nodes/VantagePointTest.php b/tests/Graph/Nodes/VantagePointTest.php new file mode 100644 index 000000000..64d9a20a2 --- /dev/null +++ b/tests/Graph/Nodes/VantagePointTest.php @@ -0,0 +1,100 @@ +node = new VantagePoint(self::CENTER, self::RADIUS, $groups); + } + + /** + * @test + */ + public function build() : void + { + $this->assertInstanceOf(VantagePoint::class, $this->node); + $this->assertInstanceOf(Hypersphere::class, $this->node); + $this->assertInstanceOf(BinaryNode::class, $this->node); + $this->assertInstanceOf(Node::class, $this->node); + } + + /** + * @test + */ + public function split() : void + { + $dataset = Labeled::quick(self::SAMPLES, self::LABELS); + + $node = VantagePoint::split($dataset, new Euclidean()); + + $this->assertEquals(self::CENTER, $node->center()); + $this->assertEquals(self::RADIUS, $node->radius()); + } + + /** + * @test + */ + public function center() : void + { + $this->assertSame(self::CENTER, $this->node->center()); + } + + /** + * @test + */ + public function radius() : void + { + $this->assertSame(self::RADIUS, $this->node->radius()); + } + + /** + * @test + */ + public function subsets() : void + { + $expected = [ + Labeled::quick([self::SAMPLES[0]], [self::LABELS[0]]), + Labeled::quick([self::SAMPLES[1]], 
[self::LABELS[1]]), + ]; + + $this->assertEquals($expected, $this->node->subsets()); + } +} diff --git a/tests/Graph/Trees/VantageTreeTest.php b/tests/Graph/Trees/VantageTreeTest.php new file mode 100644 index 000000000..06c1d75af --- /dev/null +++ b/tests/Graph/Trees/VantageTreeTest.php @@ -0,0 +1,108 @@ +generator = new Agglomerate([ + 'east' => new Blob([5, -2, -2]), + 'west' => new Blob([0, 5, -3]), + ], [0.5, 0.5]); + + $this->tree = new VantageTree(20, new Euclidean()); + + srand(self::RANDOM_SEED); + } + + protected function assertPreConditions() : void + { + $this->assertEquals(0, $this->tree->height()); + } + + /** + * @test + */ + public function build() : void + { + $this->assertInstanceOf(VantageTree::class, $this->tree); + $this->assertInstanceOf(Spatial::class, $this->tree); + $this->assertInstanceOf(BinaryTree::class, $this->tree); + $this->assertInstanceOf(Tree::class, $this->tree); + } + + /** + * @test + */ + public function growNeighborsRange() : void + { + $this->tree->grow($this->generator->generate(self::DATASET_SIZE)); + + $this->assertGreaterThan(2, $this->tree->height()); + + $sample = $this->generator->generate(1)->sample(0); + + [$samples, $labels, $distances] = $this->tree->nearest($sample, 5); + + $this->assertCount(5, $samples); + $this->assertCount(5, $labels); + $this->assertCount(5, $distances); + + $this->assertCount(1, array_unique($labels)); + + [$samples, $labels, $distances] = $this->tree->range($sample, 4.3); + + $this->assertCount(50, $samples); + $this->assertCount(50, $labels); + $this->assertCount(50, $distances); + + $this->assertCount(1, array_unique($labels)); + } + + /** + * @test + */ + public function growWithSameSamples() : void + { + $generator = new Agglomerate([ + 'east' => new Blob([5, -2, 10], 0.0), + ]); + + $dataset = $generator->generate(self::DATASET_SIZE); + + $this->tree->grow($dataset); + + $this->assertEquals(2, $this->tree->height()); + } +} diff --git a/tests/GridSearchTest.php 
b/tests/GridSearchTest.php index 0a61ddfdc..0bbc2aa4d 100644 --- a/tests/GridSearchTest.php +++ b/tests/GridSearchTest.php @@ -9,7 +9,6 @@ use Rubix\ML\GridSearch; use Rubix\ML\Persistable; use Rubix\ML\EstimatorType; -use Rubix\ML\Backends\Serial; use Rubix\ML\Loggers\BlackHole; use Rubix\ML\CrossValidation\HoldOut; use Rubix\ML\Kernels\Distance\Euclidean; @@ -126,7 +125,6 @@ public function params() : void public function trainPredictBest() : void { $this->estimator->setLogger(new BlackHole()); - $this->estimator->setBackend(new Serial()); $training = $this->generator->generate(self::TRAIN_SIZE); $testing = $this->generator->generate(self::TEST_SIZE); diff --git a/tests/Transformers/ImageRotatorTest.php b/tests/Transformers/ImageRotatorTest.php index 31297da76..10ccd5bef 100644 --- a/tests/Transformers/ImageRotatorTest.php +++ b/tests/Transformers/ImageRotatorTest.php @@ -12,7 +12,7 @@ * @requires extension gd * @covers \Rubix\ML\Transformers\ImageRotator */ -class RandomizedImageRotatorTest extends TestCase +class ImageRotatorTest extends TestCase { /** * @var ImageRotator diff --git a/tests/Transformers/MaxAbsoluteScalerTest.php b/tests/Transformers/MaxAbsoluteScalerTest.php index 7be4b1c73..0a641333d 100644 --- a/tests/Transformers/MaxAbsoluteScalerTest.php +++ b/tests/Transformers/MaxAbsoluteScalerTest.php @@ -77,9 +77,9 @@ public function fitUpdateTransformReverse() : void $this->assertCount(3, $sample); - $this->assertEqualsWithDelta(0, $sample[0], 1); - $this->assertEqualsWithDelta(0, $sample[1], 1); - $this->assertEqualsWithDelta(0, $sample[2], 1); + $this->assertEqualsWithDelta(0, $sample[0], 2 + 1e-8); + $this->assertEqualsWithDelta(0, $sample[1], 2 + 1e-8); + $this->assertEqualsWithDelta(0, $sample[2], 2 + 1e-8); $dataset->reverseApply($this->transformer); diff --git a/tests/Transformers/RobustStandardizerTest.php b/tests/Transformers/RobustStandardizerTest.php index f1759c691..c706d9bfc 100644 --- a/tests/Transformers/RobustStandardizerTest.php +++ 
b/tests/Transformers/RobustStandardizerTest.php @@ -86,7 +86,7 @@ public function fitUpdateTransformReverse() : void $dataset->reverseApply($this->transformer); - $this->assertEquals($original, $dataset->sample(0)); + $this->assertEqualsWithDelta($original, $dataset->sample(0), 1e-8); } /**