From ae8deee188f38582baf3e340ca4879e651bf1858 Mon Sep 17 00:00:00 2001 From: Chris Lloyd Date: Sun, 26 May 2024 16:08:43 +0100 Subject: [PATCH] Create a 'drop' parameter --- docs/transformers/one-hot-encoder.md | 4 ++- src/Transformers/OneHotEncoder.php | 20 ++++++++++-- tests/Transformers/OneHotEncoderTest.php | 40 ++++++++++++++++++++++-- 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/docs/transformers/one-hot-encoder.md b/docs/transformers/one-hot-encoder.md index 6530555d6..39ef15480 100644 --- a/docs/transformers/one-hot-encoder.md +++ b/docs/transformers/one-hot-encoder.md @@ -8,7 +8,9 @@ The One Hot Encoder takes a categorical feature column and produces an n-dimensi **Data Type Compatibility:** Categorical ## Parameters -This transformer does not have any parameters. +| # | Name | Default | Type | Description | +|---|------|---------|----------------|-------------| +| 1 | drop | [] | array\|string | The list of categories to drop (ignore) during categorization | ## Example ```php diff --git a/src/Transformers/OneHotEncoder.php b/src/Transformers/OneHotEncoder.php index 99373c935..5634cf58d 100644 --- a/src/Transformers/OneHotEncoder.php +++ b/src/Transformers/OneHotEncoder.php @@ -42,11 +42,25 @@ class OneHotEncoder implements Transformer, Stateful, Persistable protected ?array $categories = null; /** - * Return the data types that this transformer is compatible with. + * The categories that should be ignored * - * @internal + * @var array + */ + protected array $drop = []; + + /** + * @param string|array $drop The categories to drop during encoding + */ + public function __construct($drop = []) + { + $this->drop = is_array($drop) ? $drop : [$drop]; + } + + /** + * Return the data types that this transformer is compatible with. * * @return list<\Rubix\ML\DataType> + * @internal */ public function compatibility() : array { @@ -88,6 +102,8 @@ public function fit(Dataset $dataset) : void if ($type->isCategorical()) { $values = $dataset->feature($column); + $values = array_diff($values, $this->drop); + $categories = array_values(array_unique($values)); /** @var int[] $offsets */ diff --git a/tests/Transformers/OneHotEncoderTest.php b/tests/Transformers/OneHotEncoderTest.php index 583afb980..b6851d213 100644 --- a/tests/Transformers/OneHotEncoderTest.php +++ b/tests/Transformers/OneHotEncoderTest.php @@ -2,11 +2,11 @@ namespace Rubix\ML\Tests\Transformers; +use PHPUnit\Framework\TestCase; use Rubix\ML\Datasets\Unlabeled; +use Rubix\ML\Transformers\OneHotEncoder; use Rubix\ML\Transformers\Stateful; use Rubix\ML\Transformers\Transformer; -use Rubix\ML\Transformers\OneHotEncoder; -use PHPUnit\Framework\TestCase; /** * @group Transformers @@ -70,4 +70,40 @@ public function fitTransform() : void $this->assertEquals($expected, $dataset->samples()); } + + /** + * @test + */ + public function fitTransformNone() : void + { + $dataset = new Unlabeled([ + ['nice', 'furry', 'friendly'], + ['mean', 'furry', 'loner'], + ['nice', 'rough', 'friendly'], + ['mean', 'rough', 'friendly'], + ]); + + $this->transformer = new OneHotEncoder('furry'); + + $this->transformer->fit($dataset); + + $this->assertTrue($this->transformer->fitted()); + + $categories = $this->transformer->categories(); + + $this->assertIsArray($categories); + $this->assertCount(3, $categories); + $this->assertContainsOnly('array', $categories); + + $dataset->apply($this->transformer); + + $expected = [ + [1, 0, 0, 1, 0], + [0, 1, 0, 0, 1], + [1, 0, 1, 1, 0], + [0, 1, 1, 1, 0], + ]; + + $this->assertEquals($expected, $dataset->samples()); + } }