diff --git a/conn.go b/conn.go index 6d677995..02471a58 100644 --- a/conn.go +++ b/conn.go @@ -35,18 +35,14 @@ type Conn struct { unixFD bool uuid string - names *nameTracker - - serialGen *serialGenerator - - calls *callTracker - - handler Handler + handler Handler + signalHandler SignalHandler + serialGen SerialGenerator + names *nameTracker + calls *callTracker outHandler *outputHandler - signalHandler SignalHandler - eavesdropped chan<- *Message eavesdroppedLck sync.Mutex } @@ -89,23 +85,20 @@ func getSessionBusAddress() (string, error) { } // SessionBusPrivate returns a new private connection to the session bus. -func SessionBusPrivate() (*Conn, error) { +func SessionBusPrivate(opts ...ConnOption) (*Conn, error) { address, err := getSessionBusAddress() if err != nil { return nil, err } - return Dial(address) + return Dial(address, opts...) } // SessionBusPrivate returns a new private connection to the session bus. +// // Deprecated: use SessionBusPrivate with options instead. func SessionBusPrivateHandler(handler Handler, signalHandler SignalHandler) (*Conn, error) { - address, err := getSessionBusAddress() - if err != nil { - return nil, err - } - return Dial(address, WithHandler(handler), WithSignalHandler(signalHandler)) + return SessionBusPrivate(WithHandler(handler), WithSignalHandler(signalHandler)) } // SystemBus returns a shared connection to the system bus, connecting to it if @@ -138,14 +131,15 @@ func SystemBus() (conn *Conn, err error) { } // SystemBusPrivate returns a new private connection to the system bus. -func SystemBusPrivate() (*Conn, error) { - return Dial(getSystemBusPlatformAddress()) +func SystemBusPrivate(opts ...ConnOption) (*Conn, error) { + return Dial(getSystemBusPlatformAddress(), opts...) } // SystemBusPrivateHandler returns a new private connection to the system bus, using the provided handlers. +// // Deprecated: use SystemBusPrivate with options instead. func SystemBusPrivateHandler(handler Handler, signalHandler SignalHandler) (*Conn, error) { - return Dial(getSystemBusPlatformAddress(), WithHandler(handler), WithSignalHandler(signalHandler)) + return SystemBusPrivate(WithHandler(handler), WithSignalHandler(signalHandler)) } // Dial establishes a new private connection to the message bus specified by address. @@ -158,19 +152,16 @@ func Dial(address string, opts ...ConnOption) (*Conn, error) { } // DialHandler establishes a new private connection to the message bus specified by address, using the supplied handlers. +// // Deprecated: use Dial with options instead. func DialHandler(address string, handler Handler, signalHandler SignalHandler) (*Conn, error) { - tr, err := getTransport(address) - if err != nil { - return nil, err - } - return newConn(tr, WithHandler(handler), WithSignalHandler(signalHandler)) + return Dial(address, WithSignalHandler(signalHandler)) } // ConnOption is a connection option. type ConnOption func(conn *Conn) error -// WithHandler overrides default handler. +// WithHandler overrides the default handler. func WithHandler(handler Handler) ConnOption { return func(conn *Conn) error { conn.handler = handler @@ -178,7 +169,7 @@ func WithHandler(handler Handler) ConnOption { } } -// WithSignalHandler overrides default signal handler. +// WithSignalHandler overrides the default signal handler. func WithSignalHandler(handler SignalHandler) ConnOption { return func(conn *Conn) error { conn.signalHandler = handler @@ -186,15 +177,24 @@ func WithSignalHandler(handler SignalHandler) ConnOption { } } +// WithSerialGenerator overrides the default signals generator. +func WithSerialGenerator(gen SerialGenerator) ConnOption { + return func(conn *Conn) error { + conn.serialGen = gen + return nil + } +} + // NewConn creates a new private *Conn from an already established connection. func NewConn(conn io.ReadWriteCloser, opts ...ConnOption) (*Conn, error) { return newConn(genericTransport{conn}, opts...) } // NewConnHandler creates a new private *Conn from an already established connection, using the supplied handlers. +// // Deprecated: use NewConn with options instead. func NewConnHandler(conn io.ReadWriteCloser, handler Handler, signalHandler SignalHandler) (*Conn, error) { - return newConn(genericTransport{conn}, WithHandler(handler), WithSignalHandler(signalHandler)) + return NewConn(genericTransport{conn}, WithHandler(handler), WithSignalHandler(signalHandler)) } // newConn creates a new *Conn from a transport. @@ -213,8 +213,10 @@ func newConn(tr transport, opts ...ConnOption) (*Conn, error) { if conn.signalHandler == nil { conn.signalHandler = NewDefaultSignalHandler() } + if conn.serialGen == nil { + conn.serialGen = newSerialGenerator() + } conn.outHandler = &outputHandler{conn: conn} - conn.serialGen = newSerialGenerator() conn.names = newNameTracker() conn.busObj = conn.Object("org.freedesktop.DBus", "/org/freedesktop/DBus") return conn, nil @@ -262,9 +264,9 @@ func (conn *Conn) Eavesdrop(ch chan<- *Message) { conn.eavesdroppedLck.Unlock() } -// getSerial returns an unused serial. +// GetSerial returns an unused serial. func (conn *Conn) getSerial() uint32 { - return conn.serialGen.getSerial() + return conn.serialGen.GetSerial() } // Hello sends the initial org.freedesktop.DBus.Hello call. This method must be @@ -318,9 +320,9 @@ func (conn *Conn) inWorker() { } switch msg.Type { case TypeError: - conn.serialGen.retireSerial(conn.calls.handleDBusError(msg)) + conn.serialGen.RetireSerial(conn.calls.handleDBusError(msg)) case TypeMethodReply: - conn.serialGen.retireSerial(conn.calls.handleReply(msg)) + conn.serialGen.RetireSerial(conn.calls.handleReply(msg)) case TypeSignal: conn.handleSignal(msg) case TypeMethodCall: @@ -386,9 +388,9 @@ func (conn *Conn) sendMessageAndIfClosed(msg *Message, ifClosed func()) { err := conn.outHandler.sendAndIfClosed(msg, ifClosed) conn.calls.handleSendError(msg, err) if err != nil { - conn.serialGen.retireSerial(msg.serial) + conn.serialGen.RetireSerial(msg.serial) } else if msg.Type != TypeMethodCall { - conn.serialGen.retireSerial(msg.serial) + conn.serialGen.RetireSerial(msg.serial) } } @@ -672,7 +674,7 @@ func newSerialGenerator() *serialGenerator { } } -func (gen *serialGenerator) getSerial() uint32 { +func (gen *serialGenerator) GetSerial() uint32 { gen.lck.Lock() defer gen.lck.Unlock() n := gen.nextSerial @@ -684,7 +686,7 @@ func (gen *serialGenerator) getSerial() uint32 { return n } -func (gen *serialGenerator) retireSerial(serial uint32) { +func (gen *serialGenerator) RetireSerial(serial uint32) { gen.lck.Lock() defer gen.lck.Unlock() delete(gen.serialUsed, serial) diff --git a/default_handler.go b/default_handler.go index 54b484c2..81dbcc7e 100644 --- a/default_handler.go +++ b/default_handler.go @@ -21,6 +21,8 @@ func newIntrospectIntf(h *defaultHandler) *exportedIntf { //NewDefaultHandler returns an instance of the default //call handler. This is useful if you want to implement only //one of the two handlers but not both. +// +// Deprecated: this is the default value, don't use it, it will be unexported. func NewDefaultHandler() *defaultHandler { h := &defaultHandler{ objects: make(map[ObjectPath]*exportedObj), @@ -229,6 +231,8 @@ func (obj *exportedIntf) isFallbackInterface() bool { //NewDefaultSignalHandler returns an instance of the default //signal handler. This is useful if you want to implement only //one of the two handlers but not both. +// +// Deprecated: this is the default value, don't use it, it will be unexported. func NewDefaultSignalHandler() *defaultSignalHandler { return &defaultSignalHandler{ closeChan: make(chan struct{}), diff --git a/server_interfaces.go b/server_interfaces.go index 091948ae..01166f0b 100644 --- a/server_interfaces.go +++ b/server_interfaces.go @@ -87,3 +87,13 @@ type SignalHandler interface { type DBusError interface { DBusError() (string, []interface{}) } + +// SerialGenerator is responsible for serials generation. +// +// Different approaches for the serial generation can be used, +// maintaining a map guarded with a mutex (the standard way) or +// simply increment an atomic counter. +type SerialGenerator interface { + GetSerial() uint32 + RetireSerial(serial uint32) +} diff --git a/server_interfaces_test.go b/server_interfaces_test.go index 56231160..ca27b786 100644 --- a/server_interfaces_test.go +++ b/server_interfaces_test.go @@ -3,6 +3,7 @@ package dbus import ( "fmt" "sync" + "sync/atomic" "testing" "time" ) @@ -13,6 +14,8 @@ type tester struct { subSigsMu sync.Mutex subSigs map[string]map[string]struct{} + + serial uint32 } type intro struct { @@ -184,6 +187,12 @@ func (t *tester) Name() string { return t.conn.Names()[0] } +func (t *tester) GetSerial() uint32 { + return atomic.AddUint32(&t.serial, 1) +} + +func (t *tester) RetireSerial(serial uint32) {} + type intro_fn func() string func (intro intro_fn) Call(args ...interface{}) ([]interface{}, error) { @@ -211,7 +220,11 @@ func newTester() (*tester, error) { sigs: make(chan *Signal), subSigs: make(map[string]map[string]struct{}), } - conn, err := SessionBusPrivateHandler(tester, tester) + conn, err := SessionBusPrivate( + WithHandler(tester), + WithSignalHandler(tester), + WithSerialGenerator(tester), + ) if err != nil { return nil, err } @@ -430,7 +443,6 @@ func (x *X) Method1() *Error { return nil } - func TestRaceInExport(t *testing.T) { const ( dbusPath = "/org/example/godbus/test1"