Skip to content

Commit

Permalink
Adding one-hot to binary helper function (#456)
Browse files Browse the repository at this point in the history
* Adding one-hot to binary helper function

* Addressed comments for one-hot to binary helper function
  • Loading branch information
vaaniarora authored Dec 12, 2024
1 parent e63b71d commit 4c72f2d
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 2 deletions.
1 change: 1 addition & 0 deletions pyrtl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .helperfuncs import find_and_print_loop
from .helperfuncs import wire_struct
from .helperfuncs import wire_matrix
from .helperfuncs import one_hot_to_binary

from .corecircuits import and_all_bits
from .corecircuits import or_all_bits
Expand Down
33 changes: 32 additions & 1 deletion pyrtl/helperfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .core import working_block, _NameIndexer, _get_debug_mode, Block
from .pyrtlexceptions import PyrtlError, PyrtlInternalError
from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector
from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list
from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list, select

# -----------------------------------------------------------------
# ___ __ ___ __ __
Expand Down Expand Up @@ -1683,3 +1683,34 @@ def __len__(self):
return len(self._components)

return _WireMatrix


def one_hot_to_binary(w) -> WireVector:
'''Takes a one-hot input and returns the bit position of the high bit in binary.
:param w: WireVector or a WireVector-like object or something that can be converted
into a Const (in accordance with the :py:func:`as_wires()` required input). Example
inputs: 0b0010, 64, 0b01.
:return: The bit position of the high bit in binary as a WireVector.
If the input contains multiple 1s, the bit position of the first 1 will
be returned. If the input contains no 1s, 0 will be returned.
Examples::
one_hot_to_binary(0b0010) # returns 1
one_hot_to_binary(64) # returns 6
one_hot_to_binary(0b1100) # returns 2, the bit position of the first 1
one_hot_to_binary(0) # returns 0
'''

w = as_wires(w)

pos = 0 # Bit position of the first 1
already_found = as_wires(False) # True if first 1 already found, False otherwise

for i in range(len(w)):
pos = select(w[i] & ~already_found, i, pos)
already_found = already_found | w[i]

return pos
44 changes: 43 additions & 1 deletion tests/test_helperfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import pyrtl.helperfuncs
from pyrtl.rtllib import testingutils as utils


# ---------------------------------------------------------------


class TestWireVectorList(unittest.TestCase):
def setUp(self):
pass
Expand Down Expand Up @@ -1772,5 +1772,47 @@ def test_byte_matrix_input_concatenate(self):
self.assertEqual(sim.inspect('byte_matrix[0].low'), 0xB)


class TestOneHotToBinary(unittest.TestCase):
def setUp(self):
pyrtl.reset_working_block()

def test_simple_onehot(self):
i = pyrtl.Input(bitwidth=8, name='i')
o = pyrtl.Output(bitwidth=3, name='o')
o <<= pyrtl.one_hot_to_binary(i)

sim = pyrtl.Simulation()
sim.step({i: 0b00000001})
self.assertEqual(sim.inspect('o'), 0)
sim.step({i: 0b10000000})
self.assertEqual(sim.inspect('o'), 7)
sim.step({i: 32})
self.assertEqual(sim.inspect('o'), 5)
sim.step({i: 16})
self.assertEqual(sim.inspect('o'), 4)

def test_multiple_ones(self):
i = pyrtl.Input(bitwidth=8, name='i')
o = pyrtl.Output(bitwidth=3, name='o')
o <<= pyrtl.one_hot_to_binary(i)

sim = pyrtl.Simulation()
sim.step({i: 0b00000101})
self.assertEqual(sim.inspect('o'), 0)
sim.step({i: 0b11000000})
self.assertEqual(sim.inspect('o'), 6)
sim.step({i: 0b10010010})
self.assertEqual(sim.inspect('o'), 1)

def test_no_ones(self):
i = pyrtl.Input(bitwidth=8, name='i')
o = pyrtl.Output(bitwidth=3, name='o')
o <<= pyrtl.one_hot_to_binary(i)

sim = pyrtl.Simulation()
sim.step({i: 0b00000000})
self.assertEqual(sim.inspect('o'), 0)


if __name__ == "__main__":
unittest.main()

0 comments on commit 4c72f2d

Please sign in to comment.