Skip to content

Commit

Permalink
Create a 'drop' parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
27pchrisl committed May 27, 2024
1 parent 6e81fdf commit ae8deee
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 5 deletions.
4 changes: 3 additions & 1 deletion docs/transformers/one-hot-encoder.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ The One Hot Encoder takes a categorical feature column and produces an n-dimensi
**Data Type Compatibility:** Categorical

## Parameters
This transformer does not have any parameters.
| # | Name | Default | Type | Description |
|---|------|---------|----------------|-------------|
| 1 | drop | [] | array\|string | A category, or a list of categories, to drop (ignore) when fitting. No one-hot column is created for a dropped category. |

## Example
```php
Expand Down
20 changes: 18 additions & 2 deletions src/Transformers/OneHotEncoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,25 @@ class OneHotEncoder implements Transformer, Stateful, Persistable
protected ?array $categories = null;

/**
* Return the data types that this transformer is compatible with.
* The categories that are ignored (dropped) when fitting the encoder.
*
* @internal
* @var array<string>
*/
protected array $drop = [];

/**
 * Store the categories that should be left out when the encoder is fitted.
 *
 * A non-array argument is wrapped in a single-element array so the
 * property always holds an array.
 *
 * @param string|string[] $drop A single category or a list of categories to drop during encoding
 */
public function __construct($drop = [])
{
    // Normalise the scalar form to the array form before storing it.
    if (!is_array($drop)) {
        $drop = [$drop];
    }

    $this->drop = $drop;
}

/**
* Return the data types that this transformer is compatible with.
*
* @return list<\Rubix\ML\DataType>
* @internal
*/
public function compatibility() : array
{
Expand Down Expand Up @@ -88,6 +102,8 @@ public function fit(Dataset $dataset) : void
if ($type->isCategorical()) {
$values = $dataset->feature($column);

$values = array_diff($values, $this->drop);

$categories = array_values(array_unique($values));

/** @var int[] $offsets */
Expand Down
40 changes: 38 additions & 2 deletions tests/Transformers/OneHotEncoderTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

namespace Rubix\ML\Tests\Transformers;

use PHPUnit\Framework\TestCase;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Transformers\OneHotEncoder;
use Rubix\ML\Transformers\Stateful;
use Rubix\ML\Transformers\Transformer;
use Rubix\ML\Transformers\OneHotEncoder;
use PHPUnit\Framework\TestCase;

/**
* @group Transformers
Expand Down Expand Up @@ -70,4 +70,40 @@ public function fitTransform() : void

$this->assertEquals($expected, $dataset->samples());
}

/**
 * Fit and apply an encoder constructed with a dropped category and check
 * the resulting one-hot samples.
 *
 * @test
 */
public function fitTransformNone() : void
{
    // 'furry' is passed as the drop parameter of the encoder under test.
    $this->transformer = new OneHotEncoder('furry');

    $samples = [
        ['nice', 'furry', 'friendly'],
        ['mean', 'furry', 'loner'],
        ['nice', 'rough', 'friendly'],
        ['mean', 'rough', 'friendly'],
    ];

    $dataset = new Unlabeled($samples);

    $this->transformer->fit($dataset);

    $this->assertTrue($this->transformer->fitted());

    // One entry of fitted categories is expected per feature column.
    $fittedCategories = $this->transformer->categories();

    $this->assertIsArray($fittedCategories);
    $this->assertCount(3, $fittedCategories);
    $this->assertContainsOnly('array', $fittedCategories);

    $dataset->apply($this->transformer);

    // Five columns are expected: the rows that contained 'furry' carry a 0
    // in the only column left for the second feature (index 2).
    $this->assertEquals(
        [
            [1, 0, 0, 1, 0],
            [0, 1, 0, 0, 1],
            [1, 0, 1, 1, 0],
            [0, 1, 1, 1, 0],
        ],
        $dataset->samples()
    );
}
}

0 comments on commit ae8deee

Please sign in to comment.