Skip to content

Commit

Permalink
make sure adapters can access the context of their parent chain
Browse files Browse the repository at this point in the history
  • Loading branch information
catwell committed Oct 3, 2024
1 parent 590648e commit a969381
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/refiners/fluxion/layers/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ def replace(
new_module._set_parent(self)
if isinstance(old_module, ContextModule):
old_module._set_parent(old_module_parent)
self._register_provider()

def structural_copy(self: TChain) -> TChain:
"""Copy the structure of the Chain tree.
Expand Down
25 changes: 25 additions & 0 deletions tests/adapters/test_adapter_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts


class ContextAdapter(fl.Chain, Adapter[fl.Chain]):
def __init__(self, target: fl.Chain):
with self.setup_adapter(target):
super().__init__(
fl.Lambda(lambda: 42),
fl.SetContext("foo", "bar"),
)


class ContextChain(fl.Chain):
def init_context(self) -> Contexts:
return {"foo": {"bar": None}}


def test_adapter_can_access_parent_context():
chain = ContextChain(fl.Chain(), fl.UseContext("foo", "bar"))
adaptee = chain.layer("Chain", fl.Chain)
ContextAdapter(adaptee).inject(chain)

assert chain() == 42

0 comments on commit a969381

Please sign in to comment.