diff --git a/capability.go b/capability.go index 473827fa..dd8f1433 100644 --- a/capability.go +++ b/capability.go @@ -550,6 +550,13 @@ func (c Client) AddRef() Client { }) } +// Steal steals the receiver, and returns a new client for the same capability +// owned by the caller. This can be useful for tracking down ownership bugs. +func (c Client) Steal() Client { + defer c.Release() + return c.AddRef() +} + // WeakRef creates a new WeakClient that refers to the same capability // as c. If c is nil or has resolved to null, then WeakRef returns nil. func (c Client) WeakRef() WeakClient { @@ -639,6 +646,12 @@ func (cs ClientSnapshot) AddRef() ClientSnapshot { return cs } +// Steal is like Client.Steal() but for snapshots. +func (cs ClientSnapshot) Steal() ClientSnapshot { + defer cs.Release() + return cs.AddRef() +} + // Release the reference to the hook. func (cs ClientSnapshot) Release() { cs.hook.Release() diff --git a/captable.go b/captable.go index 95a9da6d..844cd31d 100644 --- a/captable.go +++ b/captable.go @@ -52,13 +52,13 @@ func (ct CapTable) Get(ifc Interface) (c Client) { // for the given ID already exists, it will be replaced without // releasing. func (ct CapTable) Set(id CapabilityID, c Client) { - ct.cs[id] = c + ct.cs[id] = c.Steal() } // Add appends a capability to the message's capability table and // returns its ID. It "steals" c's reference: the Message will release // the client when calling Reset. func (ct *CapTable) Add(c Client) CapabilityID { - ct.cs = append(ct.cs, c) + ct.cs = append(ct.cs, c.Steal()) return CapabilityID(ct.Len() - 1) } diff --git a/localpromise.go b/localpromise.go index a478c169..85078fb1 100644 --- a/localpromise.go +++ b/localpromise.go @@ -60,7 +60,7 @@ type localResolver[C ~ClientKind] struct { } func (lf localResolver[C]) Fulfill(c C) { - lf.lp.Fulfill(Client(c)) + lf.lp.Fulfill(Client(c).AddRef()) lf.clientResolver.Fulfill(Client(c)) } diff --git a/pogs/pogs_test.go b/pogs/pogs_test.go index 4d6e09f0..d3a10840 100644 --- a/pogs/pogs_test.go +++ b/pogs/pogs_test.go @@ -62,6 +62,39 @@ type Z struct { AnyCapability capnp.Client } +func (z Z) AddRef() Z { + switch z.Which { + case air.Z_Which_echo: + z.Echo = z.Echo.AddRef() + case air.Z_Which_echoes: + old := z.Echoes + z.Echoes = make([]air.Echo, len(old)) + for i := range old { + z.Echoes[i] = old[i].AddRef() + } + case air.Z_Which_anyCapability: + z.AnyCapability = z.AnyCapability.AddRef() + case air.Z_Which_zvec: + old := z.Zvec + z.Zvec = make([]*Z, len(old)) + for i := range old { + newRef := old[i].AddRef() + z.Zvec[i] = &newRef + } + case air.Z_Which_zvecvec: + old := z.Zvecvec + z.Zvecvec = make([][]*Z, len(old)) + for i := range old { + z.Zvecvec[i] = make([]*Z, len(old[i])) + for j := range old[i] { + newRef := old[i][j].AddRef() + z.Zvecvec[i][j] = &newRef + } + } + } + return z +} + type PlaneBase struct { Name string Homes []air.Airport @@ -241,7 +274,8 @@ func TestInsert(t *testing.T) { t.Errorf("NewRootZ for %s: %v", zpretty.Sprint(test), err) continue } - err = Insert(air.Z_TypeID, capnp.Struct(z), &test) + testCopy := test.AddRef() + err = Insert(air.Z_TypeID, capnp.Struct(z), &testCopy) if err != nil { t.Errorf("Insert(%s) error: %v", zpretty.Sprint(test), err) } @@ -1205,7 +1239,7 @@ func zfill(c air.Z, g *Z) error { c.Grp().SetSecond(g.Grp.Second) } case air.Z_Which_echo: - c.SetEcho(g.Echo) + c.SetEcho(g.Echo.AddRef()) case air.Z_Which_echoes: e, err := c.NewEchoes(int32(len(g.Echoes))) if err != nil { @@ -1215,7 +1249,7 @@ func zfill(c air.Z, g *Z) error { if !ee.IsValid() { continue } - err := e.Set(i, ee) + err := e.Set(i, ee.AddRef()) if err != nil { return err } @@ -1227,7 +1261,7 @@ func zfill(c air.Z, g *Z) error { case air.Z_Which_anyList: return c.SetAnyList(g.AnyList) case air.Z_Which_anyCapability: - return c.SetAnyCapability(g.AnyCapability) + return c.SetAnyCapability(g.AnyCapability.AddRef()) default: return fmt.Errorf("zfill: unknown type: %v", g.Which) }