Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a Steal() method to Client and ClientSnapshot #522

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,13 @@ func (c Client) AddRef() Client {
})
}

// Steal steals the receiver, and returns a new client for the same capability
// owned by the caller. This can be useful for tracking down ownership bugs.
func (c Client) Steal() Client {
defer c.Release()
return c.AddRef()
}

// WeakRef creates a new WeakClient that refers to the same capability
// as c. If c is nil or has resolved to null, then WeakRef returns nil.
func (c Client) WeakRef() WeakClient {
Expand Down Expand Up @@ -639,6 +646,12 @@ func (cs ClientSnapshot) AddRef() ClientSnapshot {
return cs
}

// Steal is like Client.Steal() but for snapshots.
func (cs ClientSnapshot) Steal() ClientSnapshot {
defer cs.Release()
return cs.AddRef()
}

// Release the reference to the hook.
func (cs ClientSnapshot) Release() {
cs.hook.Release()
Expand Down
4 changes: 2 additions & 2 deletions captable.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ func (ct CapTable) Get(ifc Interface) (c Client) {
// for the given ID already exists, it will be replaced without
// releasing.
func (ct CapTable) Set(id CapabilityID, c Client) {
ct.cs[id] = c
ct.cs[id] = c.Steal()
}

// Add appends a capability to the message's capability table and
// returns its ID. It "steals" c's reference: the Message will release
// the client when calling Reset.
func (ct *CapTable) Add(c Client) CapabilityID {
ct.cs = append(ct.cs, c)
ct.cs = append(ct.cs, c.Steal())
return CapabilityID(ct.Len() - 1)
}
2 changes: 1 addition & 1 deletion localpromise.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type localResolver[C ~ClientKind] struct {
}

func (lf localResolver[C]) Fulfill(c C) {
lf.lp.Fulfill(Client(c))
lf.lp.Fulfill(Client(c).AddRef())
lf.clientResolver.Fulfill(Client(c))
}

Expand Down
42 changes: 38 additions & 4 deletions pogs/pogs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,39 @@ type Z struct {
AnyCapability capnp.Client
}

func (z Z) AddRef() Z {
switch z.Which {
case air.Z_Which_echo:
z.Echo = z.Echo.AddRef()
case air.Z_Which_echoes:
old := z.Echoes
z.Echoes = make([]air.Echo, len(old))
for i := range old {
z.Echoes[i] = old[i].AddRef()
}
case air.Z_Which_anyCapability:
z.AnyCapability = z.AnyCapability.AddRef()
case air.Z_Which_zvec:
old := z.Zvec
z.Zvec = make([]*Z, len(old))
for i := range old {
newRef := old[i].AddRef()
z.Zvec[i] = &newRef
}
case air.Z_Which_zvecvec:
old := z.Zvecvec
z.Zvecvec = make([][]*Z, len(old))
for i := range old {
z.Zvecvec[i] = make([]*Z, len(old[i]))
for j := range old[i] {
newRef := old[i][j].AddRef()
z.Zvecvec[i][j] = &newRef
}
}
}
return z
}

type PlaneBase struct {
Name string
Homes []air.Airport
Expand Down Expand Up @@ -241,7 +274,8 @@ func TestInsert(t *testing.T) {
t.Errorf("NewRootZ for %s: %v", zpretty.Sprint(test), err)
continue
}
err = Insert(air.Z_TypeID, capnp.Struct(z), &test)
testCopy := test.AddRef()
err = Insert(air.Z_TypeID, capnp.Struct(z), &testCopy)
if err != nil {
t.Errorf("Insert(%s) error: %v", zpretty.Sprint(test), err)
}
Expand Down Expand Up @@ -1205,7 +1239,7 @@ func zfill(c air.Z, g *Z) error {
c.Grp().SetSecond(g.Grp.Second)
}
case air.Z_Which_echo:
c.SetEcho(g.Echo)
c.SetEcho(g.Echo.AddRef())
case air.Z_Which_echoes:
e, err := c.NewEchoes(int32(len(g.Echoes)))
if err != nil {
Expand All @@ -1215,7 +1249,7 @@ func zfill(c air.Z, g *Z) error {
if !ee.IsValid() {
continue
}
err := e.Set(i, ee)
err := e.Set(i, ee.AddRef())
if err != nil {
return err
}
Expand All @@ -1227,7 +1261,7 @@ func zfill(c air.Z, g *Z) error {
case air.Z_Which_anyList:
return c.SetAnyList(g.AnyList)
case air.Z_Which_anyCapability:
return c.SetAnyCapability(g.AnyCapability)
return c.SetAnyCapability(g.AnyCapability.AddRef())
default:
return fmt.Errorf("zfill: unknown type: %v", g.Which)
}
Expand Down