From 3b415a4be7c1a31fc329e42d779332bedf1046ad Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Sat, 4 Jan 2020 17:12:25 +0900 Subject: [PATCH 1/3] implement driver.DriverContext --- fakedb_go110_test.go | 101 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 fakedb_go110_test.go diff --git a/fakedb_go110_test.go b/fakedb_go110_test.go new file mode 100644 index 0000000..4462a80 --- /dev/null +++ b/fakedb_go110_test.go @@ -0,0 +1,101 @@ +// +build go1.10 + +package proxy + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" +) + +type fakeDriverCtx fakeDriver +type fakeConnector struct { + driver *fakeDriverCtx + opt *fakeConnOption + db *fakeDB +} + +var fdriverctx = &fakeDriver{} +var _ driver.DriverContext = &fakeDriverCtx{} +var _ driver.Connector = &fakeConnector{} + +func init() { + sql.Register("fakedbctx", fdriverctx) +} + +func (d *fakeDriverCtx) Open(name string) (driver.Conn, error) { + return nil, errors.New("not implemented") +} + +func (d *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { + var opt fakeConnOption + err := json.Unmarshal([]byte(name), &opt) + if err != nil { + return nil, err + } + + // validate options + switch opt.ConnType { + case "", "fakeConn", "fakeConnExt", "fakeConnCtx": + // validation OK + default: + return nil, errors.New("known ConnType") + } + + d.mu.Lock() + defer d.mu.Unlock() + db, ok := d.dbs[opt.Name] + if !ok { + db = &fakeDB{ + log: &bytes.Buffer{}, + } + if d.dbs == nil { + d.dbs = make(map[string]*fakeDB) + } + d.dbs[name] = db + } + + return &fakeConnector{ + driver: d, + opt: &opt, + db: db, + }, nil +} + +func (d *fakeDriverCtx) DB(name string) *fakeDB { + d.mu.Lock() + defer d.mu.Unlock() + return d.dbs[name] +} + +func (c *fakeConnector) Connect(ctx context.Context) (driver.Conn, error) { + var conn driver.Conn + switch c.opt.ConnType { + case "", "fakeConn": + conn = &fakeConn{ + db: c.db, + opt: c.opt, + } + case "fakeConnExt": + conn = &fakeConnExt{ + db: c.db, + opt: c.opt, + } + case "fakeConnCtx": + conn = &fakeConnCtx{ + db: c.db, + opt: c.opt, + } + default: + return nil, errors.New("known ConnType") + } + + return conn, nil +} + +func (c *fakeConnector) Driver() driver.Driver { + return c.driver +} From 0311edabfe4558bb2f283bac0faa53130c47ef5d Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Sat, 4 Jan 2020 17:18:40 +0900 Subject: [PATCH 2/3] add hooks.go --- hooks.go | 1265 +++++++++++++++++++++++++++++++++++++++++++++++++ hooks_test.go | 334 +++++++++++++ proxy.go | 1257 ------------------------------------------------ proxy_test.go | 328 ------------- 4 files changed, 1599 insertions(+), 1585 deletions(-) create mode 100644 hooks.go create mode 100644 hooks_test.go diff --git a/hooks.go b/hooks.go new file mode 100644 index 0000000..fc13fe3 --- /dev/null +++ b/hooks.go @@ -0,0 +1,1265 @@ +// a proxy package is a proxy driver for database/sql. + +package proxy + +import ( + "context" + "database/sql/driver" + "errors" +) + +// hooks is callback functions for the proxy. +// it is private because it doesn't guarantee backward compatibility. +type hooks interface { + prePing(c context.Context, conn *Conn) (interface{}, error) + ping(c context.Context, ctx interface{}, conn *Conn) error + postPing(c context.Context, ctx interface{}, conn *Conn, err error) error + preOpen(c context.Context, name string) (interface{}, error) + open(c context.Context, ctx interface{}, conn *Conn) error + postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error + preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) + exec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error + postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error + preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) + query(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error + postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error + preBegin(c context.Context, conn *Conn) (interface{}, error) + begin(c context.Context, ctx interface{}, conn *Conn) error + postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error + preCommit(c context.Context, tx *Tx) (interface{}, error) + commit(c context.Context, ctx interface{}, tx *Tx) error + postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error + preRollback(c context.Context, tx *Tx) (interface{}, error) + rollback(c context.Context, ctx interface{}, tx *Tx) error + postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error + preClose(c context.Context, conn *Conn) (interface{}, error) + close(c context.Context, ctx interface{}, conn *Conn) error + postClose(c context.Context, ctx interface{}, conn *Conn, err error) error + preResetSession(c context.Context, conn *Conn) (interface{}, error) + resetSession(c context.Context, ctx interface{}, conn *Conn) error + postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error +} + +// HooksContext is callback functions with context.Context for the proxy. +type HooksContext struct { + // PrePing is a callback that gets called prior to calling + // `Conn.Ping`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Conn.Ping` and `Hooks.Ping` methods + // are not called. + // + // The first return value is passed to both `Hooks.Ping` and + // `Hooks.PostPing` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Conn.Ping` method and `Hooks.Ping` + // methods are not called. + PrePing func(c context.Context, conn *Conn) (interface{}, error) + + // Ping is called after the underlying driver's `Conn.Exec` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PrePing` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Conn.Ping` method. + Ping func(c context.Context, ctx interface{}, conn *Conn) error + + // PostPing is a callback that gets called at the end of + // the call to `Conn.Ping`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PrePing` method, and may be nil. + PostPing func(c context.Context, ctx interface{}, conn *Conn, err error) error + + // PreOpen is a callback that gets called before any + // attempt to open the sql connection is made, and is ALWAYS + // called. + // + // The first return value is passed to both `Hooks.Open` and + // `Hooks.PostOpen` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Driver.Open` method and `Hooks.Open` + // methods are not called. + PreOpen func(c context.Context, name string) (interface{}, error) + + // Open is called after the underlying driver's `Driver.Open` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreOpen` method, and may be nil. + // + // If this callback returns an error, then the `conn` object is + // closed by calling the `Close` method, and the error from this + // callback is returned by the `db.Open` method. + Open func(c context.Context, ctx interface{}, conn *Conn) error + + // PostOpen is a callback that gets called at the end of + // the call to `db.Open(). It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreOpen` method, and may be nil. + PostOpen func(c context.Context, ctx interface{}, conn *Conn, err error) error + + // PreExec is a callback that gets called prior to calling + // `Stmt.Exec`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Stmt.Exec` and `Hooks.Exec` methods + // are not called. + // + // The first return value is passed to both `Hooks.Exec` and + // `Hooks.PostExec` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Driver.Exec` method and `Hooks.Exec` + // methods are not called. + PreExec func(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) + + // Exec is called after the underlying driver's `Driver.Exec` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreExec` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Stmt.Exec` method. + Exec func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error + + // PostExec is a callback that gets called at the end of + // the call to `Stmt.Exec`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreExec` method, and may be nil. + PostExec func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error + + // PreQuery is a callback that gets called prior to calling + // `Stmt.Query`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Stmt.Query` and `Hooks.Query` methods + // are not called. + // + // The first return value is passed to both `Hooks.Query` and + // `Hooks.PostQuery` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Stmt.Query` method and `Hooks.Query` + // methods are not called. + PreQuery func(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) + + // Query is called after the underlying driver's `Stmt.Query` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreQuery` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Stmt.Query` method. + Query func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error + + // PostQuery is a callback that gets called at the end of + // the call to `Stmt.Query`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreQuery` method, and may be nil. + PostQuery func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error + + // PreBegin is a callback that gets called prior to calling + // `Stmt.Begin`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Conn.Begin` and `Hooks.Begin` methods + // are not called. + // + // The first return value is passed to both `Hooks.Begin` and + // `Hooks.PostBegin` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Conn.Begin` method and `Hooks.Begin` + // methods are not called. + PreBegin func(c context.Context, conn *Conn) (interface{}, error) + + // Begin is called after the underlying driver's `Conn.Begin` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreBegin` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Conn.Begin` method. + Begin func(c context.Context, ctx interface{}, conn *Conn) error + + // PostBegin is a callback that gets called at the end of + // the call to `Conn.Begin`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreBegin` method, and may be nil. + PostBegin func(c context.Context, ctx interface{}, conn *Conn, err error) error + + // PreCommit is a callback that gets called prior to calling + // `Tx.Commit`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Tx.Commit` and `Hooks.Commit` methods + // are not called. + // + // The first return value is passed to both `Hooks.Commit` and + // `Hooks.PostCommit` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Tx.Commit` method and `Hooks.Commit` + // methods are not called. + PreCommit func(c context.Context, tx *Tx) (interface{}, error) + + // Commit is called after the underlying driver's `Tx.Commit` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreCommit` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Tx.Commit` method. + Commit func(c context.Context, ctx interface{}, tx *Tx) error + + // PostCommit is a callback that gets called at the end of + // the call to `Tx.Commit`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreCommit` method, and may be nil. + PostCommit func(c context.Context, ctx interface{}, tx *Tx, err error) error + + // PreRollback is a callback that gets called prior to calling + // `Tx.Rollback`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Tx.Rollback` and `Hooks.Rollback` methods + // are not called. + // + // The first return value is passed to both `Hooks.Rollback` and + // `Hooks.PostRollback` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Tx.Rollback` method and `Hooks.Rollback` + PreRollback func(c context.Context, tx *Tx) (interface{}, error) + + // Rollback is called after the underlying driver's `Tx.Rollback` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreRollback` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Tx.Rollback` method. + Rollback func(c context.Context, ctx interface{}, tx *Tx) error + + // PostRollback is a callback that gets called at the end of + // the call to `Tx.Rollback`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreRollback` method, and may be nil. + PostRollback func(c context.Context, ctx interface{}, tx *Tx, err error) error + + // PreClose is a callback that gets called prior to calling + // `Conn.Close`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Conn.Close` and `Hooks.Close` methods + // are not called. + // + // The first return value is passed to both `Hooks.Close` and + // `Hooks.PostClose` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Conn.Close` method and `Hooks.Close` + // methods are not called. + PreClose func(c context.Context, conn *Conn) (interface{}, error) + + // Close is called after the underlying driver's `Conn.Close` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreClose` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Conn.Close` method. + Close func(c context.Context, ctx interface{}, conn *Conn) error + + // PostClose is a callback that gets called at the end of + // the call to `Conn.Close`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreClose` method, and may be nil. + PostClose func(c context.Context, ctx interface{}, conn *Conn, err error) error + + // PreResetSession is a callback that gets called prior to calling + // `Conn.ResetSession`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Conn.ResetSession` and `Hooks.ResetSession` methods + // are not called. + // + // The first return value is passed to both `Hooks.ResetSession` and + // `Hooks.PostResetSession` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Conn.ResetSession` method and `Hooks.ResetSession` + // methods are not called. + PreResetSession func(c context.Context, conn *Conn) (interface{}, error) + + // ResetSession is called after the underlying driver's `Conn.ResetSession` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreResetSession` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Conn.ResetSession` method. + ResetSession func(c context.Context, ctx interface{}, conn *Conn) error + + // PostResetSession is a callback that gets called at the end of + // the call to `Conn.ResetSession`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreResetSession` method, and may be nil. + PostResetSession func(c context.Context, ctx interface{}, conn *Conn, err error) error +} + +func (h *HooksContext) prePing(c context.Context, conn *Conn) (interface{}, error) { + if h == nil || h.PrePing == nil { + return nil, nil + } + return h.PrePing(c, conn) +} + +func (h *HooksContext) ping(c context.Context, ctx interface{}, conn *Conn) error { + if h == nil || h.Ping == nil { + return nil + } + return h.Ping(c, ctx, conn) +} + +func (h *HooksContext) postPing(c context.Context, ctx interface{}, conn *Conn, err error) error { + if h == nil || h.PostPing == nil { + return nil + } + return h.PostPing(c, ctx, conn, err) +} + +func (h *HooksContext) preOpen(c context.Context, name string) (interface{}, error) { + if h == nil || h.PreOpen == nil { + return nil, nil + } + return h.PreOpen(c, name) +} + +func (h *HooksContext) open(c context.Context, ctx interface{}, conn *Conn) error { + if h == nil || h.Open == nil { + return nil + } + return h.Open(c, ctx, conn) +} + +func (h *HooksContext) postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error { + if h == nil || h.PostOpen == nil { + return nil + } + return h.PostOpen(c, ctx, conn, err) +} + +func (h *HooksContext) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { + if h == nil || h.PreExec == nil { + return nil, nil + } + return h.PreExec(c, stmt, args) +} + +func (h *HooksContext) exec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error { + if h == nil || h.Exec == nil { + return nil + } + return h.Exec(c, ctx, stmt, args, result) +} + +func (h *HooksContext) postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error { + if h == nil || h.PostExec == nil { + return nil + } + return h.PostExec(c, ctx, stmt, args, result, err) +} + +func (h *HooksContext) preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { + if h == nil || h.PreQuery == nil { + return nil, nil + } + return h.PreQuery(c, stmt, args) +} + +func (h *HooksContext) query(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error { + if h == nil || h.Query == nil { + return nil + } + return h.Query(c, ctx, stmt, args, rows) +} + +func (h *HooksContext) postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error { + if h == nil || h.PostQuery == nil { + return nil + } + return h.PostQuery(c, ctx, stmt, args, rows, err) +} + +func (h *HooksContext) preBegin(c context.Context, conn *Conn) (interface{}, error) { + if h == nil || h.PreBegin == nil { + return nil, nil + } + return h.PreBegin(c, conn) +} + +func (h *HooksContext) begin(c context.Context, ctx interface{}, conn *Conn) error { + if h == nil || h.Open == nil { + return nil + } + return h.Begin(c, ctx, conn) +} + +func (h *HooksContext) postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error { + if h == nil || h.PostBegin == nil { + return nil + } + return h.PostBegin(c, ctx, conn, err) +} + +func (h *HooksContext) preCommit(c context.Context, tx *Tx) (interface{}, error) { + if h == nil || h.PreCommit == nil { + return nil, nil + } + return h.PreCommit(c, tx) +} + +func (h *HooksContext) commit(c context.Context, ctx interface{}, tx *Tx) error { + if h == nil || h.Commit == nil { + return nil + } + return h.Commit(c, ctx, tx) +} + +func (h *HooksContext) postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error { + if h == nil || h.PostCommit == nil { + return nil + } + return h.PostCommit(c, ctx, tx, err) +} + +func (h *HooksContext) preRollback(c context.Context, tx *Tx) (interface{}, error) { + if h == nil || h.PreRollback == nil { + return nil, nil + } + return h.PreRollback(c, tx) +} + +func (h *HooksContext) rollback(c context.Context, ctx interface{}, tx *Tx) error { + if h == nil || h.Rollback == nil { + return nil + } + return h.Rollback(c, ctx, tx) +} + +func (h *HooksContext) postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error { + if h == nil || h.PostRollback == nil { + return nil + } + return h.PostRollback(c, ctx, tx, err) +} + +func (h *HooksContext) preClose(c context.Context, conn *Conn) (interface{}, error) { + if h == nil || h.PreClose == nil { + return nil, nil + } + return h.PreClose(c, conn) +} + +func (h *HooksContext) close(c context.Context, ctx interface{}, conn *Conn) error { + if h == nil || h.Close == nil { + return nil + } + return h.Close(c, ctx, conn) +} + +func (h *HooksContext) postClose(c context.Context, ctx interface{}, conn *Conn, err error) error { + if h == nil || h.PostClose == nil { + return nil + } + return h.PostClose(c, ctx, conn, err) +} + +func (h *HooksContext) preResetSession(c context.Context, conn *Conn) (interface{}, error) { + if h == nil || h.PreResetSession == nil { + return nil, nil + } + return h.PreResetSession(c, conn) +} + +func (h *HooksContext) resetSession(c context.Context, ctx interface{}, conn *Conn) error { + if h == nil || h.ResetSession == nil { + return nil + } + return h.ResetSession(c, ctx, conn) +} + +func (h *HooksContext) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error { + if h == nil || h.PostResetSession == nil { + return nil + } + return h.PostResetSession(c, ctx, conn, err) +} + +// Hooks is callback functions for the proxy. +// Deprecated: You should use HooksContext instead. +type Hooks struct { + // PrePing is a callback that gets called prior to calling + // `Conn.Ping`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Conn.Ping` and `Hooks.Ping` methods + // are not called. + // + // The first return value is passed to both `Hooks.Ping` and + // `Hooks.PostPing` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Conn.Ping` method and `Hooks.Ping` + // methods are not called. + PrePing func(conn *Conn) (interface{}, error) + + // Ping is called after the underlying driver's `Conn.Exec` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PrePing` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Conn.Ping` method. + Ping func(ctx interface{}, conn *Conn) error + + // PostPing is a callback that gets called at the end of + // the call to `Conn.Ping`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PrePing` method, and may be nil. + PostPing func(ctx interface{}, conn *Conn, err error) error + + // PreOpen is a callback that gets called before any + // attempt to open the sql connection is made, and is ALWAYS + // called. + // + // The first return value is passed to both `Hooks.Open` and + // `Hooks.PostOpen` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Driver.Open` method and `Hooks.Open` + // methods are not called. + PreOpen func(name string) (interface{}, error) + + // Open is called after the underlying driver's `Driver.Open` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreOpen` method, and may be nil. + // + // If this callback returns an error, then the `conn` object is + // closed by calling the `Close` method, and the error from this + // callback is returned by the `db.Open` method. + Open func(ctx interface{}, conn *Conn) error + + // PostOpen is a callback that gets called at the end of + // the call to `db.Open(). It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreOpen` method, and may be nil. + PostOpen func(ctx interface{}, conn *Conn) error + + // PreExec is a callback that gets called prior to calling + // `Stmt.Exec`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Stmt.Exec` and `Hooks.Exec` methods + // are not called. + // + // The first return value is passed to both `Hooks.Exec` and + // `Hooks.PostExec` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Driver.Exec` method and `Hooks.Exec` + // methods are not called. + PreExec func(stmt *Stmt, args []driver.Value) (interface{}, error) + + // Exec is called after the underlying driver's `Driver.Exec` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreExec` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Stmt.Exec` method. + Exec func(ctx interface{}, stmt *Stmt, args []driver.Value, result driver.Result) error + + // PostExec is a callback that gets called at the end of + // the call to `Stmt.Exec`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreExec` method, and may be nil. + PostExec func(ctx interface{}, stmt *Stmt, args []driver.Value, result driver.Result) error + + // PreQuery is a callback that gets called prior to calling + // `Stmt.Query`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Stmt.Query` and `Hooks.Query` methods + // are not called. + // + // The first return value is passed to both `Hooks.Query` and + // `Hooks.PostQuery` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Stmt.Query` method and `Hooks.Query` + // methods are not called. + PreQuery func(stmt *Stmt, args []driver.Value) (interface{}, error) + + // Query is called after the underlying driver's `Stmt.Query` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreQuery` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Stmt.Query` method. + Query func(ctx interface{}, stmt *Stmt, args []driver.Value, rows driver.Rows) error + + // PostQuery is a callback that gets called at the end of + // the call to `Stmt.Query`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreQuery` method, and may be nil. + PostQuery func(ctx interface{}, stmt *Stmt, args []driver.Value, rows driver.Rows) error + + // PreBegin is a callback that gets called prior to calling + // `Stmt.Begin`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Conn.Begin` and `Hooks.Begin` methods + // are not called. + // + // The first return value is passed to both `Hooks.Begin` and + // `Hooks.PostBegin` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Conn.Begin` method and `Hooks.Begin` + // methods are not called. + PreBegin func(conn *Conn) (interface{}, error) + + // Begin is called after the underlying driver's `Conn.Begin` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreBegin` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Conn.Begin` method. + Begin func(ctx interface{}, conn *Conn) error + + // PostBegin is a callback that gets called at the end of + // the call to `Conn.Begin`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreBegin` method, and may be nil. + PostBegin func(ctx interface{}, conn *Conn) error + + // PreCommit is a callback that gets called prior to calling + // `Tx.Commit`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Tx.Commit` and `Hooks.Commit` methods + // are not called. + // + // The first return value is passed to both `Hooks.Commit` and + // `Hooks.PostCommit` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Tx.Commit` method and `Hooks.Commit` + // methods are not called. + PreCommit func(tx *Tx) (interface{}, error) + + // Commit is called after the underlying driver's `Tx.Commit` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreCommit` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Tx.Commit` method. + Commit func(ctx interface{}, tx *Tx) error + + // PostCommit is a callback that gets called at the end of + // the call to `Tx.Commit`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreCommit` method, and may be nil. + PostCommit func(ctx interface{}, tx *Tx) error + + // PreRollback is a callback that gets called prior to calling + // `Tx.Rollback`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Tx.Rollback` and `Hooks.Rollback` methods + // are not called. + // + // The first return value is passed to both `Hooks.Rollback` and + // `Hooks.PostRollback` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Tx.Rollback` method and `Hooks.Rollback` + PreRollback func(tx *Tx) (interface{}, error) + + // Rollback is called after the underlying driver's `Tx.Rollback` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreRollback` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Tx.Rollback` method. + Rollback func(ctx interface{}, tx *Tx) error + + // PostRollback is a callback that gets called at the end of + // the call to `Tx.Rollback`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreRollback` method, and may be nil. + PostRollback func(ctx interface{}, tx *Tx) error + + // PreClose is a callback that gets called prior to calling + // `Conn.Close`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Conn.Close` and `Hooks.Close` methods + // are not called. + // + // The first return value is passed to both `Hooks.Close` and + // `Hooks.PostClose` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Conn.Close` method and `Hooks.Close` + // methods are not called. + PreClose func(conn *Conn) (interface{}, error) + + // Close is called after the underlying driver's `Conn.Close` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreClose` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Conn.Close` method. + Close func(ctx interface{}, conn *Conn) error + + // PostClose is a callback that gets called at the end of + // the call to `Conn.Close`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreClose` method, and may be nil. + PostClose func(ctx interface{}, conn *Conn, err error) error + + // PreResetSession is a callback that gets called prior to calling + // `Conn.ResetSession`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `Conn.ResetSession` and `Hooks.ResetSession` methods + // are not called. + // + // The first return value is passed to both `Hooks.ResetSession` and + // `Hooks.PostResetSession` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. If this callback returns an error, + // the underlying driver's `Conn.ResetSession` method and `Hooks.ResetSession` + // methods are not called. + PreResetSession func(conn *Conn) (interface{}, error) + + // ResetSession is called after the underlying driver's `Conn.ResetSession` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreResetSession` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `Conn.ResetSession` method. + ResetSession func(ctx interface{}, conn *Conn) error + + // PostResetSession is a callback that gets called at the end of + // the call to `Conn.ResetSession`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PreResetSession` method, and may be nil. + PostResetSession func(ctx interface{}, conn *Conn, err error) error +} + +func namedValuesToValues(args []driver.NamedValue) ([]driver.Value, error) { + var err error + ret := make([]driver.Value, len(args)) + for _, arg := range args { + if len(arg.Name) > 0 { + err = errors.New("proxy: driver does not support the use of Named Parameters") + } + ret[arg.Ordinal-1] = arg.Value + } + return ret, err +} + +func valuesToNamedValues(args []driver.Value) []driver.NamedValue { + ret := make([]driver.NamedValue, len(args)) + for i, arg := range args { + ret[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: arg, + } + } + return ret +} + +func (h *Hooks) prePing(c context.Context, conn *Conn) (interface{}, error) { + if h == nil || h.PrePing == nil { + return nil, nil + } + return h.PrePing(conn) +} + +func (h *Hooks) ping(c context.Context, ctx interface{}, conn *Conn) error { + if h == nil || h.Ping == nil { + return nil + } + return h.Ping(ctx, conn) +} + +func (h *Hooks) postPing(c context.Context, ctx interface{}, conn *Conn, err error) error { + if h == nil || h.PostPing == nil { + return nil + } + return h.PostPing(ctx, conn, err) +} + +func (h *Hooks) preOpen(c context.Context, name string) (interface{}, error) { + if h == nil || h.PreOpen == nil { + return nil, nil + } + return h.PreOpen(name) +} + +func (h *Hooks) open(c context.Context, ctx interface{}, conn *Conn) error { + if h == nil || h.Open == nil { + return nil + } + return h.Open(ctx, conn) +} + +func (h *Hooks) postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error { + if h == nil || h.PostOpen == nil { + return nil + } + return h.PostOpen(ctx, conn) +} + +func (h *Hooks) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { + if h == nil || h.PreExec == nil { + return nil, nil + } + dargs, _ := namedValuesToValues(args) + return h.PreExec(stmt, dargs) +} + +func (h *Hooks) exec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error { + if h == nil || h.Exec == nil { + return nil + } + dargs, _ := namedValuesToValues(args) + return h.Exec(ctx, stmt, dargs, result) +} + +func (h *Hooks) postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error { + if h == nil || h.PostExec == nil { + return nil + } + dargs, _ := namedValuesToValues(args) + return h.PostExec(ctx, stmt, dargs, result) +} + +func (h *Hooks) preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { + if h == nil || h.PreQuery == nil { + return nil, nil + } + dargs, _ := namedValuesToValues(args) + return h.PreQuery(stmt, dargs) +} + +func (h *Hooks) query(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error { + if h == nil || h.Query == nil { + return nil + } + dargs, _ := namedValuesToValues(args) + return h.Query(ctx, stmt, dargs, rows) +} + +func (h *Hooks) postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error { + if h == nil || h.PostQuery == nil { + return nil + } + dargs, _ := namedValuesToValues(args) + return h.PostQuery(ctx, stmt, dargs, rows) +} + +func (h *Hooks) preBegin(c context.Context, conn *Conn) (interface{}, error) { + if h == nil || h.PreBegin == nil { + return nil, nil + } + return h.PreBegin(conn) +} + +func (h *Hooks) begin(c context.Context, ctx interface{}, conn *Conn) error { + if h == nil || h.Open == nil { + return nil + } + return h.Begin(ctx, conn) +} + +func (h *Hooks) postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error { + if h == nil || h.PostBegin == nil { + return nil + } + return h.PostBegin(ctx, conn) +} + +func (h *Hooks) preCommit(c context.Context, tx *Tx) (interface{}, error) { + if h == nil || h.PreCommit == nil { + return nil, nil + } + return h.PreCommit(tx) +} + +func (h *Hooks) commit(c context.Context, ctx interface{}, tx *Tx) error { + if h == nil || h.Commit == nil { + return nil + } + return h.Commit(ctx, tx) +} + +func (h *Hooks) postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error { + if h == nil || h.PostCommit == nil { + return nil + } + return h.PostCommit(ctx, tx) +} + +func (h *Hooks) preRollback(c context.Context, tx *Tx) (interface{}, error) { + if h == nil || h.PreRollback == nil { + return nil, nil + } + return h.PreRollback(tx) +} + +func (h *Hooks) rollback(c context.Context, ctx interface{}, tx *Tx) error { + if h == nil || h.Rollback == nil { + return nil + } + return h.Rollback(ctx, tx) +} + +func (h *Hooks) postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error { + if h == nil || h.PostRollback == nil { + return nil + } + return h.PostRollback(ctx, tx) +} + +func (h *Hooks) preClose(c context.Context, conn *Conn) (interface{}, error) { + if h == nil || h.PreClose == nil { + return nil, nil + } + return h.PreClose(conn) +} + +func (h *Hooks) close(c context.Context, ctx interface{}, conn *Conn) error { + if h == nil || h.Close == nil { + return nil + } + return h.Close(ctx, conn) +} + +func (h *Hooks) postClose(c context.Context, ctx interface{}, conn *Conn, err error) error { + if h == nil || h.PostClose == nil { + return nil + } + return h.PostClose(ctx, conn, err) +} + +func (h *Hooks) preResetSession(c context.Context, conn *Conn) (interface{}, error) { + if h == nil || h.PreResetSession == nil { + return nil, nil + } + return h.PreResetSession(conn) +} + +func (h *Hooks) resetSession(c context.Context, ctx interface{}, conn *Conn) error { + if h == nil || h.ResetSession == nil { + return nil + } + return h.ResetSession(ctx, conn) +} + +func (h *Hooks) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error { + if h == nil || h.PostResetSession == nil { + return nil + } + return h.PostResetSession(ctx, conn, err) +} + +type multipleHooks []hooks + +func (h multipleHooks) preDo(f func(h hooks) (interface{}, error)) (interface{}, error) { + if h == nil { + return nil, nil + } + ctx := make([]interface{}, len(h)) + var err error + for i, hk := range h { + ctx0, err0 := f(hk) + ctx[i] = ctx0 + if err0 != nil && err == nil { + err = err0 + } + } + return ctx, err +} + +func (h multipleHooks) do(ctx interface{}, f func(h hooks, ctx interface{}) error) error { + if h == nil { + return nil + } + sctx, ok := ctx.([]interface{}) + if !ok { + return errors.New("invalid context type") + } + for i, hk := range h { + if err := f(hk, sctx[i]); err != nil { + return err + } + } + return nil +} + +func (h multipleHooks) postDo(ctx interface{}, err error, f func(h hooks, ctx interface{}, err error) error) error { + if h == nil { + return nil + } + sctx, ok := ctx.([]interface{}) + if !ok { + return errors.New("invalid context type") + } + var reterr error + for i := len(h) - 1; i >= 0; i-- { + if err0 := f(h[i], sctx[i], err); err0 != nil { + if err == nil { + err = err0 + } + if reterr == nil { + reterr = err0 + } + } + } + return reterr +} + +func (h multipleHooks) prePing(c context.Context, conn *Conn) (interface{}, error) { + return h.preDo(func(h hooks) (interface{}, error) { + return h.prePing(c, conn) + }) +} + +func (h multipleHooks) ping(c context.Context, ctx interface{}, conn *Conn) error { + return h.do(ctx, func(h hooks, ctx interface{}) error { + return h.ping(c, ctx, conn) + }) +} + +func (h multipleHooks) postPing(c context.Context, ctx interface{}, conn *Conn, err error) error { + return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { + return h.postPing(c, ctx, conn, err) + }) +} + +func (h multipleHooks) preOpen(c context.Context, name string) (interface{}, error) { + return h.preDo(func(h hooks) (interface{}, error) { + return h.preOpen(c, name) + }) +} + +func (h multipleHooks) open(c context.Context, ctx interface{}, conn *Conn) error { + return h.do(ctx, func(h hooks, ctx interface{}) error { + return h.open(c, ctx, conn) + }) +} + +func (h multipleHooks) postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error { + return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { + return h.postOpen(c, ctx, conn, err) + }) +} + +func (h multipleHooks) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { + return h.preDo(func(h hooks) (interface{}, error) { + return h.preExec(c, stmt, args) + }) +} + +func (h multipleHooks) exec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error { + return h.do(ctx, func(h hooks, ctx interface{}) error { + return h.exec(c, ctx, stmt, args, result) + }) +} + +func (h multipleHooks) postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error { + return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { + return h.postExec(c, ctx, stmt, args, result, err) + }) +} + +func (h multipleHooks) preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { + return h.preDo(func(h hooks) (interface{}, error) { + return h.preQuery(c, stmt, args) + }) +} + +func (h multipleHooks) query(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error { + return h.do(ctx, func(h hooks, ctx interface{}) error { + return h.query(c, ctx, stmt, args, rows) + }) +} + +func (h multipleHooks) postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error { + return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { + return h.postQuery(c, ctx, stmt, args, rows, err) + }) +} + +func (h multipleHooks) preBegin(c context.Context, conn *Conn) (interface{}, error) { + return h.preDo(func(h hooks) (interface{}, error) { + return h.preBegin(c, conn) + }) +} + +func (h multipleHooks) begin(c context.Context, ctx interface{}, conn *Conn) error { + return h.do(ctx, func(h hooks, ctx interface{}) error { + return h.begin(c, ctx, conn) + }) +} + +func (h multipleHooks) postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error { + return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { + return h.postBegin(c, ctx, conn, err) + }) +} + +func (h multipleHooks) preCommit(c context.Context, tx *Tx) (interface{}, error) { + return h.preDo(func(h hooks) (interface{}, error) { + return h.preCommit(c, tx) + }) +} + +func (h multipleHooks) commit(c context.Context, ctx interface{}, tx *Tx) error { + return h.do(ctx, func(h hooks, ctx interface{}) error { + return h.commit(c, ctx, tx) + }) +} + +func (h multipleHooks) postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error { + return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { + return h.postCommit(c, ctx, tx, err) + }) +} + +func (h multipleHooks) preRollback(c context.Context, tx *Tx) (interface{}, error) { + return h.preDo(func(h hooks) (interface{}, error) { + return h.preRollback(c, tx) + }) +} + +func (h multipleHooks) rollback(c context.Context, ctx interface{}, tx *Tx) error { + return h.do(ctx, func(h hooks, ctx interface{}) error { + return h.rollback(c, ctx, tx) + }) +} + +func (h multipleHooks) postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error { + return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { + return h.postRollback(c, ctx, tx, err) + }) +} + +func (h multipleHooks) preClose(c context.Context, conn *Conn) (interface{}, error) { + return h.preDo(func(h hooks) (interface{}, error) { + return h.preClose(c, conn) + }) +} + +func (h multipleHooks) close(c context.Context, ctx interface{}, conn *Conn) error { + return h.do(ctx, func(h hooks, ctx interface{}) error { + return h.close(c, ctx, conn) + }) +} + +func (h multipleHooks) postClose(c context.Context, ctx interface{}, conn *Conn, err error) error { + return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { + return h.postClose(c, ctx, conn, err) + }) +} + +func (h multipleHooks) preResetSession(c context.Context, conn *Conn) (interface{}, error) { + return h.preDo(func(h hooks) (interface{}, error) { + return h.preResetSession(c, conn) + }) +} + +func (h multipleHooks) resetSession(c context.Context, ctx interface{}, conn *Conn) error { + return h.do(ctx, func(h hooks, ctx interface{}) error { + return h.resetSession(c, ctx, conn) + }) +} + +func (h multipleHooks) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error { + return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { + return h.postResetSession(c, ctx, conn, err) + }) +} + +type contextHooksKey struct{} + +// WithHooks returns a copy of parent context in which the hooks associated. +func WithHooks(ctx context.Context, hs ...*HooksContext) context.Context { + switch len(hs) { + case 0: + return context.WithValue(ctx, contextHooksKey{}, (*HooksContext)(nil)) + case 1: + return context.WithValue(ctx, contextHooksKey{}, hs[0]) + } + + hooksSlice := make([]hooks, len(hs)) + for i, hk := range hs { + hooksSlice[i] = hk + } + return context.WithValue(ctx, contextHooksKey{}, multipleHooks(hooksSlice)) +} diff --git a/hooks_test.go b/hooks_test.go new file mode 100644 index 0000000..6a1e122 --- /dev/null +++ b/hooks_test.go @@ -0,0 +1,334 @@ +package proxy + +import ( + "context" + "database/sql/driver" + "testing" +) + +func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { + c := context.Background() + if ctx2, err := h.preOpen(c, ""); ctx2 != ctx || err != nil { + t.Errorf("preOpen returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) + } + if err := h.open(c, ctx, nil); err != nil { + t.Error("open returns error: ", err) + } + if err := h.postOpen(c, ctx, nil, nil); err != nil { + t.Error("postOpen returns error: ", err) + } + if ctx2, err := h.prePing(c, nil); ctx2 != ctx || err != nil { + t.Errorf("prePing returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) + } + if err := h.ping(c, ctx, nil); err != nil { + t.Error("ping returns error: ", err) + } + if err := h.postPing(c, ctx, nil, nil); err != nil { + t.Error("postPing returns error: ", err) + } + if ctx2, err := h.preExec(c, nil, nil); ctx2 != ctx || err != nil { + t.Errorf("preExec returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) + } + if err := h.exec(c, ctx, nil, nil, nil); err != nil { + t.Error("exec returns error: ", err) + } + if err := h.postExec(c, ctx, nil, nil, nil, nil); err != nil { + t.Error("postExec returns error: ", err) + } + if ctx2, err := h.preQuery(c, nil, nil); ctx2 != ctx || err != nil { + t.Errorf("preQuery returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) + } + if err := h.query(c, ctx, nil, nil, nil); err != nil { + t.Error("query returns error: ", err) + } + if err := h.postQuery(c, ctx, nil, nil, nil, nil); err != nil { + t.Error("postQuery returns error: ", err) + } + if ctx2, err := h.preBegin(c, nil); ctx2 != ctx || err != nil { + t.Errorf("preBegin returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) + } + if err := h.begin(c, ctx, nil); err != nil { + t.Error("begin returns error: ", err) + } + if err := h.postBegin(c, ctx, nil, nil); err != nil { + t.Error("postBegin returns error: ", err) + } + if ctx2, err := h.preCommit(c, nil); ctx2 != ctx || err != nil { + t.Errorf("preCommit returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) + } + if err := h.commit(c, ctx, nil); err != nil { + t.Error("commit returns error: ", err) + } + if err := h.postCommit(c, ctx, nil, nil); err != nil { + t.Error("postCommit returns error: ", err) + } + if ctx2, err := h.preRollback(c, nil); ctx2 != ctx || err != nil { + t.Errorf("preRollback returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) + } + if err := h.rollback(c, ctx, nil); err != nil { + t.Error("rollback returns error: ", err) + } + if err := h.postRollback(c, ctx, nil, nil); err != nil { + t.Error("postRollback returns error: ", err) + } + if ctx2, err := h.preClose(c, nil); ctx2 != ctx || err != nil { + t.Errorf("preClose returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) + } + if err := h.close(c, ctx, nil); err != nil { + t.Error("close returns error: ", err) + } + if err := h.postClose(c, ctx, nil, nil); err != nil { + t.Error("postClose returns error: ", err) + } + if ctx2, err := h.preResetSession(c, nil); ctx2 != ctx || err != nil { + t.Errorf("preResetSession returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) + } + if err := h.resetSession(c, ctx, nil); err != nil { + t.Error("resetSession returns error: ", err) + } + if err := h.postResetSession(c, ctx, nil, nil); err != nil { + t.Error("postResetSession returns error: ", err) + } +} + +func TestNilHooksContext(t *testing.T) { + // nil HooksContext will not panic and have no effec + testHooksInterface(t, (*HooksContext)(nil), nil) +} + +func TestZeroHooksContext(t *testing.T) { + // zero HooksContext will not panic and have no effec + testHooksInterface(t, &HooksContext{}, nil) +} + +func TestHooksContext(t *testing.T) { + dummy := 0 + ctx0 := &dummy + checkCtx := func(name string, ctx interface{}) { + if ctx != ctx0 { + t.Errorf("unexpected ctx: got %v want %v in %s", ctx, ctx0, name) + } + } + testHooksInterface(t, &HooksContext{ + PrePing: func(c context.Context, conn *Conn) (interface{}, error) { + return ctx0, nil + }, + Ping: func(c context.Context, ctx interface{}, conn *Conn) error { + checkCtx("Ping", ctx) + return nil + }, + PostPing: func(c context.Context, ctx interface{}, conn *Conn, err error) error { + checkCtx("PostPing", ctx) + return err + }, + PreOpen: func(c context.Context, name string) (interface{}, error) { + return ctx0, nil + }, + Open: func(c context.Context, ctx interface{}, conn *Conn) error { + checkCtx("Open", ctx) + return nil + }, + PostOpen: func(c context.Context, ctx interface{}, conn *Conn, err error) error { + checkCtx("PostOpen", ctx) + return err + }, + PreExec: func(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { + return ctx0, nil + }, + Exec: func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error { + checkCtx("Exec", ctx) + return nil + }, + PostExec: func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error { + checkCtx("PostExec", ctx) + return err + }, + PreQuery: func(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { + return ctx0, nil + }, + Query: func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error { + checkCtx("Query", ctx) + return nil + }, + PostQuery: func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error { + checkCtx("PostQuery", ctx) + return err + }, + PreBegin: func(c context.Context, conn *Conn) (interface{}, error) { + return ctx0, nil + }, + Begin: func(c context.Context, ctx interface{}, conn *Conn) error { + checkCtx("Begin", ctx) + return nil + }, + PostBegin: func(c context.Context, ctx interface{}, conn *Conn, err error) error { + checkCtx("PostBegin", ctx) + return err + }, + PreCommit: func(c context.Context, tx *Tx) (interface{}, error) { + return ctx0, nil + }, + Commit: func(c context.Context, ctx interface{}, tx *Tx) error { + checkCtx("Commit", ctx) + return nil + }, + PostCommit: func(c context.Context, ctx interface{}, tx *Tx, err error) error { + checkCtx("PostCommit", ctx) + return err + }, + PreRollback: func(c context.Context, tx *Tx) (interface{}, error) { + return ctx0, nil + }, + Rollback: func(c context.Context, ctx interface{}, tx *Tx) error { + checkCtx("Rollback", ctx) + return nil + }, + PostRollback: func(c context.Context, ctx interface{}, tx *Tx, err error) error { + checkCtx("PostRollback", ctx) + return err + }, + PreClose: func(c context.Context, conn *Conn) (interface{}, error) { + return ctx0, nil + }, + Close: func(c context.Context, ctx interface{}, conn *Conn) error { + checkCtx("Close", ctx) + return nil + }, + PostClose: func(c context.Context, ctx interface{}, conn *Conn, err error) error { + checkCtx("PostClose", ctx) + return err + }, + PreResetSession: func(c context.Context, conn *Conn) (interface{}, error) { + return ctx0, nil + }, + ResetSession: func(c context.Context, ctx interface{}, conn *Conn) error { + checkCtx("ResetSession", ctx) + return nil + }, + PostResetSession: func(c context.Context, ctx interface{}, conn *Conn, err error) error { + checkCtx("PostResetSession", ctx) + return err + }, + }, ctx0) +} + +func TestNilHooks(t *testing.T) { + // nil Hooks will not panic and have no effect + testHooksInterface(t, (*Hooks)(nil), nil) +} + +func TestZeroHooks(t *testing.T) { + // zero Hooks will not panic and have no effect + testHooksInterface(t, &Hooks{}, nil) +} + +func TestHooks(t *testing.T) { + dummy := 0 + ctx0 := &dummy + checkCtx := func(name string, ctx interface{}) { + if ctx != ctx0 { + t.Errorf("unexpected ctx: got %v want %v in %s", ctx, ctx0, name) + } + } + testHooksInterface(t, &Hooks{ + PrePing: func(conn *Conn) (interface{}, error) { + return ctx0, nil + }, + Ping: func(ctx interface{}, conn *Conn) error { + checkCtx("Ping", ctx) + return nil + }, + PostPing: func(ctx interface{}, conn *Conn, err error) error { + checkCtx("PostPing", ctx) + return err + }, + PreOpen: func(name string) (interface{}, error) { + return ctx0, nil + }, + Open: func(ctx interface{}, conn *Conn) error { + checkCtx("Open", ctx) + return nil + }, + PostOpen: func(ctx interface{}, conn *Conn) error { + checkCtx("PostOpen", ctx) + return nil + }, + PreExec: func(stmt *Stmt, args []driver.Value) (interface{}, error) { + return ctx0, nil + }, + Exec: func(ctx interface{}, stmt *Stmt, args []driver.Value, result driver.Result) error { + checkCtx("Exec", ctx) + return nil + }, + PostExec: func(ctx interface{}, stmt *Stmt, args []driver.Value, result driver.Result) error { + checkCtx("PostExec", ctx) + return nil + }, + PreQuery: func(stmt *Stmt, args []driver.Value) (interface{}, error) { + return ctx0, nil + }, + Query: func(ctx interface{}, stmt *Stmt, args []driver.Value, rows driver.Rows) error { + checkCtx("Query", ctx) + return nil + }, + PostQuery: func(ctx interface{}, stmt *Stmt, args []driver.Value, rows driver.Rows) error { + checkCtx("PostQuery", ctx) + return nil + }, + PreBegin: func(conn *Conn) (interface{}, error) { + return ctx0, nil + }, + Begin: func(ctx interface{}, conn *Conn) error { + checkCtx("Begin", ctx) + return nil + }, + PostBegin: func(ctx interface{}, conn *Conn) error { + checkCtx("PostBegin", ctx) + return nil + }, + PreCommit: func(tx *Tx) (interface{}, error) { + return ctx0, nil + }, + Commit: func(ctx interface{}, tx *Tx) error { + checkCtx("Commit", ctx) + return nil + }, + PostCommit: func(ctx interface{}, tx *Tx) error { + checkCtx("PostCommit", ctx) + return nil + }, + PreRollback: func(tx *Tx) (interface{}, error) { + return ctx0, nil + }, + Rollback: func(ctx interface{}, tx *Tx) error { + checkCtx("Rollback", ctx) + return nil + }, + PostRollback: func(ctx interface{}, tx *Tx) error { + checkCtx("PostRollback", ctx) + return nil + }, + PreClose: func(conn *Conn) (interface{}, error) { + return ctx0, nil + }, + Close: func(ctx interface{}, conn *Conn) error { + checkCtx("Close", ctx) + return nil + }, + PostClose: func(ctx interface{}, conn *Conn, err error) error { + checkCtx("PostClose", ctx) + return err + }, + PreResetSession: func(conn *Conn) (interface{}, error) { + return ctx0, nil + }, + ResetSession: func(ctx interface{}, conn *Conn) error { + checkCtx("ResetSession", ctx) + return nil + }, + PostResetSession: func(ctx interface{}, conn *Conn, err error) error { + checkCtx("PostResetSession", ctx) + return err + }, + }, ctx0) +} diff --git a/proxy.go b/proxy.go index 74116ce..4f1a4a6 100644 --- a/proxy.go +++ b/proxy.go @@ -5,7 +5,6 @@ package proxy import ( "context" "database/sql/driver" - "errors" ) // Proxy is a sql driver. @@ -15,1244 +14,6 @@ type Proxy struct { hooks hooks } -// hooks is callback functions for the proxy. -// it is private because it doesn't guarantee backward compatibility. -type hooks interface { - prePing(c context.Context, conn *Conn) (interface{}, error) - ping(c context.Context, ctx interface{}, conn *Conn) error - postPing(c context.Context, ctx interface{}, conn *Conn, err error) error - preOpen(c context.Context, name string) (interface{}, error) - open(c context.Context, ctx interface{}, conn *Conn) error - postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error - preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) - exec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error - postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error - preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) - query(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error - postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error - preBegin(c context.Context, conn *Conn) (interface{}, error) - begin(c context.Context, ctx interface{}, conn *Conn) error - postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error - preCommit(c context.Context, tx *Tx) (interface{}, error) - commit(c context.Context, ctx interface{}, tx *Tx) error - postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error - preRollback(c context.Context, tx *Tx) (interface{}, error) - rollback(c context.Context, ctx interface{}, tx *Tx) error - postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error - preClose(c context.Context, conn *Conn) (interface{}, error) - close(c context.Context, ctx interface{}, conn *Conn) error - postClose(c context.Context, ctx interface{}, conn *Conn, err error) error - preResetSession(c context.Context, conn *Conn) (interface{}, error) - resetSession(c context.Context, ctx interface{}, conn *Conn) error - postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error -} - -// HooksContext is callback functions with context.Context for the proxy. -type HooksContext struct { - // PrePing is a callback that gets called prior to calling - // `Conn.Ping`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Conn.Ping` and `Hooks.Ping` methods - // are not called. - // - // The first return value is passed to both `Hooks.Ping` and - // `Hooks.PostPing` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Conn.Ping` method and `Hooks.Ping` - // methods are not called. - PrePing func(c context.Context, conn *Conn) (interface{}, error) - - // Ping is called after the underlying driver's `Conn.Exec` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PrePing` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Conn.Ping` method. - Ping func(c context.Context, ctx interface{}, conn *Conn) error - - // PostPing is a callback that gets called at the end of - // the call to `Conn.Ping`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PrePing` method, and may be nil. - PostPing func(c context.Context, ctx interface{}, conn *Conn, err error) error - - // PreOpen is a callback that gets called before any - // attempt to open the sql connection is made, and is ALWAYS - // called. - // - // The first return value is passed to both `Hooks.Open` and - // `Hooks.PostOpen` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Driver.Open` method and `Hooks.Open` - // methods are not called. - PreOpen func(c context.Context, name string) (interface{}, error) - - // Open is called after the underlying driver's `Driver.Open` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreOpen` method, and may be nil. - // - // If this callback returns an error, then the `conn` object is - // closed by calling the `Close` method, and the error from this - // callback is returned by the `db.Open` method. - Open func(c context.Context, ctx interface{}, conn *Conn) error - - // PostOpen is a callback that gets called at the end of - // the call to `db.Open(). It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreOpen` method, and may be nil. - PostOpen func(c context.Context, ctx interface{}, conn *Conn, err error) error - - // PreExec is a callback that gets called prior to calling - // `Stmt.Exec`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Stmt.Exec` and `Hooks.Exec` methods - // are not called. - // - // The first return value is passed to both `Hooks.Exec` and - // `Hooks.PostExec` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Driver.Exec` method and `Hooks.Exec` - // methods are not called. - PreExec func(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) - - // Exec is called after the underlying driver's `Driver.Exec` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreExec` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Stmt.Exec` method. - Exec func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error - - // PostExec is a callback that gets called at the end of - // the call to `Stmt.Exec`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreExec` method, and may be nil. - PostExec func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error - - // PreQuery is a callback that gets called prior to calling - // `Stmt.Query`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Stmt.Query` and `Hooks.Query` methods - // are not called. - // - // The first return value is passed to both `Hooks.Query` and - // `Hooks.PostQuery` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Stmt.Query` method and `Hooks.Query` - // methods are not called. - PreQuery func(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) - - // Query is called after the underlying driver's `Stmt.Query` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreQuery` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Stmt.Query` method. - Query func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error - - // PostQuery is a callback that gets called at the end of - // the call to `Stmt.Query`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreQuery` method, and may be nil. - PostQuery func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error - - // PreBegin is a callback that gets called prior to calling - // `Stmt.Begin`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Conn.Begin` and `Hooks.Begin` methods - // are not called. - // - // The first return value is passed to both `Hooks.Begin` and - // `Hooks.PostBegin` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Conn.Begin` method and `Hooks.Begin` - // methods are not called. - PreBegin func(c context.Context, conn *Conn) (interface{}, error) - - // Begin is called after the underlying driver's `Conn.Begin` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreBegin` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Conn.Begin` method. - Begin func(c context.Context, ctx interface{}, conn *Conn) error - - // PostBegin is a callback that gets called at the end of - // the call to `Conn.Begin`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreBegin` method, and may be nil. - PostBegin func(c context.Context, ctx interface{}, conn *Conn, err error) error - - // PreCommit is a callback that gets called prior to calling - // `Tx.Commit`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Tx.Commit` and `Hooks.Commit` methods - // are not called. - // - // The first return value is passed to both `Hooks.Commit` and - // `Hooks.PostCommit` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Tx.Commit` method and `Hooks.Commit` - // methods are not called. - PreCommit func(c context.Context, tx *Tx) (interface{}, error) - - // Commit is called after the underlying driver's `Tx.Commit` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreCommit` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Tx.Commit` method. - Commit func(c context.Context, ctx interface{}, tx *Tx) error - - // PostCommit is a callback that gets called at the end of - // the call to `Tx.Commit`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreCommit` method, and may be nil. - PostCommit func(c context.Context, ctx interface{}, tx *Tx, err error) error - - // PreRollback is a callback that gets called prior to calling - // `Tx.Rollback`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Tx.Rollback` and `Hooks.Rollback` methods - // are not called. - // - // The first return value is passed to both `Hooks.Rollback` and - // `Hooks.PostRollback` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Tx.Rollback` method and `Hooks.Rollback` - PreRollback func(c context.Context, tx *Tx) (interface{}, error) - - // Rollback is called after the underlying driver's `Tx.Rollback` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreRollback` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Tx.Rollback` method. - Rollback func(c context.Context, ctx interface{}, tx *Tx) error - - // PostRollback is a callback that gets called at the end of - // the call to `Tx.Rollback`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreRollback` method, and may be nil. - PostRollback func(c context.Context, ctx interface{}, tx *Tx, err error) error - - // PreClose is a callback that gets called prior to calling - // `Conn.Close`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Conn.Close` and `Hooks.Close` methods - // are not called. - // - // The first return value is passed to both `Hooks.Close` and - // `Hooks.PostClose` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Conn.Close` method and `Hooks.Close` - // methods are not called. - PreClose func(c context.Context, conn *Conn) (interface{}, error) - - // Close is called after the underlying driver's `Conn.Close` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreClose` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Conn.Close` method. - Close func(c context.Context, ctx interface{}, conn *Conn) error - - // PostClose is a callback that gets called at the end of - // the call to `Conn.Close`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreClose` method, and may be nil. - PostClose func(c context.Context, ctx interface{}, conn *Conn, err error) error - - // PreResetSession is a callback that gets called prior to calling - // `Conn.ResetSession`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Conn.ResetSession` and `Hooks.ResetSession` methods - // are not called. - // - // The first return value is passed to both `Hooks.ResetSession` and - // `Hooks.PostResetSession` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Conn.ResetSession` method and `Hooks.ResetSession` - // methods are not called. - PreResetSession func(c context.Context, conn *Conn) (interface{}, error) - - // ResetSession is called after the underlying driver's `Conn.ResetSession` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreResetSession` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Conn.ResetSession` method. - ResetSession func(c context.Context, ctx interface{}, conn *Conn) error - - // PostResetSession is a callback that gets called at the end of - // the call to `Conn.ResetSession`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreResetSession` method, and may be nil. - PostResetSession func(c context.Context, ctx interface{}, conn *Conn, err error) error -} - -func (h *HooksContext) prePing(c context.Context, conn *Conn) (interface{}, error) { - if h == nil || h.PrePing == nil { - return nil, nil - } - return h.PrePing(c, conn) -} - -func (h *HooksContext) ping(c context.Context, ctx interface{}, conn *Conn) error { - if h == nil || h.Ping == nil { - return nil - } - return h.Ping(c, ctx, conn) -} - -func (h *HooksContext) postPing(c context.Context, ctx interface{}, conn *Conn, err error) error { - if h == nil || h.PostPing == nil { - return nil - } - return h.PostPing(c, ctx, conn, err) -} - -func (h *HooksContext) preOpen(c context.Context, name string) (interface{}, error) { - if h == nil || h.PreOpen == nil { - return nil, nil - } - return h.PreOpen(c, name) -} - -func (h *HooksContext) open(c context.Context, ctx interface{}, conn *Conn) error { - if h == nil || h.Open == nil { - return nil - } - return h.Open(c, ctx, conn) -} - -func (h *HooksContext) postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error { - if h == nil || h.PostOpen == nil { - return nil - } - return h.PostOpen(c, ctx, conn, err) -} - -func (h *HooksContext) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { - if h == nil || h.PreExec == nil { - return nil, nil - } - return h.PreExec(c, stmt, args) -} - -func (h *HooksContext) exec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error { - if h == nil || h.Exec == nil { - return nil - } - return h.Exec(c, ctx, stmt, args, result) -} - -func (h *HooksContext) postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error { - if h == nil || h.PostExec == nil { - return nil - } - return h.PostExec(c, ctx, stmt, args, result, err) -} - -func (h *HooksContext) preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { - if h == nil || h.PreQuery == nil { - return nil, nil - } - return h.PreQuery(c, stmt, args) -} - -func (h *HooksContext) query(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error { - if h == nil || h.Query == nil { - return nil - } - return h.Query(c, ctx, stmt, args, rows) -} - -func (h *HooksContext) postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error { - if h == nil || h.PostQuery == nil { - return nil - } - return h.PostQuery(c, ctx, stmt, args, rows, err) -} - -func (h *HooksContext) preBegin(c context.Context, conn *Conn) (interface{}, error) { - if h == nil || h.PreBegin == nil { - return nil, nil - } - return h.PreBegin(c, conn) -} - -func (h *HooksContext) begin(c context.Context, ctx interface{}, conn *Conn) error { - if h == nil || h.Open == nil { - return nil - } - return h.Begin(c, ctx, conn) -} - -func (h *HooksContext) postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error { - if h == nil || h.PostBegin == nil { - return nil - } - return h.PostBegin(c, ctx, conn, err) -} - -func (h *HooksContext) preCommit(c context.Context, tx *Tx) (interface{}, error) { - if h == nil || h.PreCommit == nil { - return nil, nil - } - return h.PreCommit(c, tx) -} - -func (h *HooksContext) commit(c context.Context, ctx interface{}, tx *Tx) error { - if h == nil || h.Commit == nil { - return nil - } - return h.Commit(c, ctx, tx) -} - -func (h *HooksContext) postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error { - if h == nil || h.PostCommit == nil { - return nil - } - return h.PostCommit(c, ctx, tx, err) -} - -func (h *HooksContext) preRollback(c context.Context, tx *Tx) (interface{}, error) { - if h == nil || h.PreRollback == nil { - return nil, nil - } - return h.PreRollback(c, tx) -} - -func (h *HooksContext) rollback(c context.Context, ctx interface{}, tx *Tx) error { - if h == nil || h.Rollback == nil { - return nil - } - return h.Rollback(c, ctx, tx) -} - -func (h *HooksContext) postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error { - if h == nil || h.PostRollback == nil { - return nil - } - return h.PostRollback(c, ctx, tx, err) -} - -func (h *HooksContext) preClose(c context.Context, conn *Conn) (interface{}, error) { - if h == nil || h.PreClose == nil { - return nil, nil - } - return h.PreClose(c, conn) -} - -func (h *HooksContext) close(c context.Context, ctx interface{}, conn *Conn) error { - if h == nil || h.Close == nil { - return nil - } - return h.Close(c, ctx, conn) -} - -func (h *HooksContext) postClose(c context.Context, ctx interface{}, conn *Conn, err error) error { - if h == nil || h.PostClose == nil { - return nil - } - return h.PostClose(c, ctx, conn, err) -} - -func (h *HooksContext) preResetSession(c context.Context, conn *Conn) (interface{}, error) { - if h == nil || h.PreResetSession == nil { - return nil, nil - } - return h.PreResetSession(c, conn) -} - -func (h *HooksContext) resetSession(c context.Context, ctx interface{}, conn *Conn) error { - if h == nil || h.ResetSession == nil { - return nil - } - return h.ResetSession(c, ctx, conn) -} - -func (h *HooksContext) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error { - if h == nil || h.PostResetSession == nil { - return nil - } - return h.PostResetSession(c, ctx, conn, err) -} - -// Hooks is callback functions for the proxy. -// Deprecated: You should use HooksContext instead. -type Hooks struct { - // PrePing is a callback that gets called prior to calling - // `Conn.Ping`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Conn.Ping` and `Hooks.Ping` methods - // are not called. - // - // The first return value is passed to both `Hooks.Ping` and - // `Hooks.PostPing` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Conn.Ping` method and `Hooks.Ping` - // methods are not called. - PrePing func(conn *Conn) (interface{}, error) - - // Ping is called after the underlying driver's `Conn.Exec` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PrePing` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Conn.Ping` method. - Ping func(ctx interface{}, conn *Conn) error - - // PostPing is a callback that gets called at the end of - // the call to `Conn.Ping`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PrePing` method, and may be nil. - PostPing func(ctx interface{}, conn *Conn, err error) error - - // PreOpen is a callback that gets called before any - // attempt to open the sql connection is made, and is ALWAYS - // called. - // - // The first return value is passed to both `Hooks.Open` and - // `Hooks.PostOpen` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Driver.Open` method and `Hooks.Open` - // methods are not called. - PreOpen func(name string) (interface{}, error) - - // Open is called after the underlying driver's `Driver.Open` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreOpen` method, and may be nil. - // - // If this callback returns an error, then the `conn` object is - // closed by calling the `Close` method, and the error from this - // callback is returned by the `db.Open` method. - Open func(ctx interface{}, conn *Conn) error - - // PostOpen is a callback that gets called at the end of - // the call to `db.Open(). It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreOpen` method, and may be nil. - PostOpen func(ctx interface{}, conn *Conn) error - - // PreExec is a callback that gets called prior to calling - // `Stmt.Exec`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Stmt.Exec` and `Hooks.Exec` methods - // are not called. - // - // The first return value is passed to both `Hooks.Exec` and - // `Hooks.PostExec` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Driver.Exec` method and `Hooks.Exec` - // methods are not called. - PreExec func(stmt *Stmt, args []driver.Value) (interface{}, error) - - // Exec is called after the underlying driver's `Driver.Exec` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreExec` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Stmt.Exec` method. - Exec func(ctx interface{}, stmt *Stmt, args []driver.Value, result driver.Result) error - - // PostExec is a callback that gets called at the end of - // the call to `Stmt.Exec`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreExec` method, and may be nil. - PostExec func(ctx interface{}, stmt *Stmt, args []driver.Value, result driver.Result) error - - // PreQuery is a callback that gets called prior to calling - // `Stmt.Query`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Stmt.Query` and `Hooks.Query` methods - // are not called. - // - // The first return value is passed to both `Hooks.Query` and - // `Hooks.PostQuery` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Stmt.Query` method and `Hooks.Query` - // methods are not called. - PreQuery func(stmt *Stmt, args []driver.Value) (interface{}, error) - - // Query is called after the underlying driver's `Stmt.Query` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreQuery` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Stmt.Query` method. - Query func(ctx interface{}, stmt *Stmt, args []driver.Value, rows driver.Rows) error - - // PostQuery is a callback that gets called at the end of - // the call to `Stmt.Query`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreQuery` method, and may be nil. - PostQuery func(ctx interface{}, stmt *Stmt, args []driver.Value, rows driver.Rows) error - - // PreBegin is a callback that gets called prior to calling - // `Stmt.Begin`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Conn.Begin` and `Hooks.Begin` methods - // are not called. - // - // The first return value is passed to both `Hooks.Begin` and - // `Hooks.PostBegin` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Conn.Begin` method and `Hooks.Begin` - // methods are not called. - PreBegin func(conn *Conn) (interface{}, error) - - // Begin is called after the underlying driver's `Conn.Begin` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreBegin` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Conn.Begin` method. - Begin func(ctx interface{}, conn *Conn) error - - // PostBegin is a callback that gets called at the end of - // the call to `Conn.Begin`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreBegin` method, and may be nil. - PostBegin func(ctx interface{}, conn *Conn) error - - // PreCommit is a callback that gets called prior to calling - // `Tx.Commit`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Tx.Commit` and `Hooks.Commit` methods - // are not called. - // - // The first return value is passed to both `Hooks.Commit` and - // `Hooks.PostCommit` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Tx.Commit` method and `Hooks.Commit` - // methods are not called. - PreCommit func(tx *Tx) (interface{}, error) - - // Commit is called after the underlying driver's `Tx.Commit` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreCommit` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Tx.Commit` method. - Commit func(ctx interface{}, tx *Tx) error - - // PostCommit is a callback that gets called at the end of - // the call to `Tx.Commit`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreCommit` method, and may be nil. - PostCommit func(ctx interface{}, tx *Tx) error - - // PreRollback is a callback that gets called prior to calling - // `Tx.Rollback`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Tx.Rollback` and `Hooks.Rollback` methods - // are not called. - // - // The first return value is passed to both `Hooks.Rollback` and - // `Hooks.PostRollback` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Tx.Rollback` method and `Hooks.Rollback` - PreRollback func(tx *Tx) (interface{}, error) - - // Rollback is called after the underlying driver's `Tx.Rollback` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreRollback` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Tx.Rollback` method. - Rollback func(ctx interface{}, tx *Tx) error - - // PostRollback is a callback that gets called at the end of - // the call to `Tx.Rollback`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreRollback` method, and may be nil. - PostRollback func(ctx interface{}, tx *Tx) error - - // PreClose is a callback that gets called prior to calling - // `Conn.Close`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Conn.Close` and `Hooks.Close` methods - // are not called. - // - // The first return value is passed to both `Hooks.Close` and - // `Hooks.PostClose` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Conn.Close` method and `Hooks.Close` - // methods are not called. - PreClose func(conn *Conn) (interface{}, error) - - // Close is called after the underlying driver's `Conn.Close` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreClose` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Conn.Close` method. - Close func(ctx interface{}, conn *Conn) error - - // PostClose is a callback that gets called at the end of - // the call to `Conn.Close`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreClose` method, and may be nil. - PostClose func(ctx interface{}, conn *Conn, err error) error - - // PreResetSession is a callback that gets called prior to calling - // `Conn.ResetSession`, and is ALWAYS called. If this callback returns an - // error, the underlying driver's `Conn.ResetSession` and `Hooks.ResetSession` methods - // are not called. - // - // The first return value is passed to both `Hooks.ResetSession` and - // `Hooks.PostResetSession` callbacks. You may specify anything you want. - // Return nil if you do not need to use it. - // - // The second return value is indicates the error found while - // executing this hook. If this callback returns an error, - // the underlying driver's `Conn.ResetSession` method and `Hooks.ResetSession` - // methods are not called. - PreResetSession func(conn *Conn) (interface{}, error) - - // ResetSession is called after the underlying driver's `Conn.ResetSession` method - // returns without any errors. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreResetSession` method, and may be nil. - // - // If this callback returns an error, then the error from this - // callback is returned by the `Conn.ResetSession` method. - ResetSession func(ctx interface{}, conn *Conn) error - - // PostResetSession is a callback that gets called at the end of - // the call to `Conn.ResetSession`. It is ALWAYS called. - // - // The `ctx` parameter is the return value supplied from the - // `Hooks.PreResetSession` method, and may be nil. - PostResetSession func(ctx interface{}, conn *Conn, err error) error -} - -func namedValuesToValues(args []driver.NamedValue) ([]driver.Value, error) { - var err error - ret := make([]driver.Value, len(args)) - for _, arg := range args { - if len(arg.Name) > 0 { - err = errors.New("proxy: driver does not support the use of Named Parameters") - } - ret[arg.Ordinal-1] = arg.Value - } - return ret, err -} - -func valuesToNamedValues(args []driver.Value) []driver.NamedValue { - ret := make([]driver.NamedValue, len(args)) - for i, arg := range args { - ret[i] = driver.NamedValue{ - Ordinal: i + 1, - Value: arg, - } - } - return ret -} - -func (h *Hooks) prePing(c context.Context, conn *Conn) (interface{}, error) { - if h == nil || h.PrePing == nil { - return nil, nil - } - return h.PrePing(conn) -} - -func (h *Hooks) ping(c context.Context, ctx interface{}, conn *Conn) error { - if h == nil || h.Ping == nil { - return nil - } - return h.Ping(ctx, conn) -} - -func (h *Hooks) postPing(c context.Context, ctx interface{}, conn *Conn, err error) error { - if h == nil || h.PostPing == nil { - return nil - } - return h.PostPing(ctx, conn, err) -} - -func (h *Hooks) preOpen(c context.Context, name string) (interface{}, error) { - if h == nil || h.PreOpen == nil { - return nil, nil - } - return h.PreOpen(name) -} - -func (h *Hooks) open(c context.Context, ctx interface{}, conn *Conn) error { - if h == nil || h.Open == nil { - return nil - } - return h.Open(ctx, conn) -} - -func (h *Hooks) postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error { - if h == nil || h.PostOpen == nil { - return nil - } - return h.PostOpen(ctx, conn) -} - -func (h *Hooks) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { - if h == nil || h.PreExec == nil { - return nil, nil - } - dargs, _ := namedValuesToValues(args) - return h.PreExec(stmt, dargs) -} - -func (h *Hooks) exec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error { - if h == nil || h.Exec == nil { - return nil - } - dargs, _ := namedValuesToValues(args) - return h.Exec(ctx, stmt, dargs, result) -} - -func (h *Hooks) postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error { - if h == nil || h.PostExec == nil { - return nil - } - dargs, _ := namedValuesToValues(args) - return h.PostExec(ctx, stmt, dargs, result) -} - -func (h *Hooks) preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { - if h == nil || h.PreQuery == nil { - return nil, nil - } - dargs, _ := namedValuesToValues(args) - return h.PreQuery(stmt, dargs) -} - -func (h *Hooks) query(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error { - if h == nil || h.Query == nil { - return nil - } - dargs, _ := namedValuesToValues(args) - return h.Query(ctx, stmt, dargs, rows) -} - -func (h *Hooks) postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error { - if h == nil || h.PostQuery == nil { - return nil - } - dargs, _ := namedValuesToValues(args) - return h.PostQuery(ctx, stmt, dargs, rows) -} - -func (h *Hooks) preBegin(c context.Context, conn *Conn) (interface{}, error) { - if h == nil || h.PreBegin == nil { - return nil, nil - } - return h.PreBegin(conn) -} - -func (h *Hooks) begin(c context.Context, ctx interface{}, conn *Conn) error { - if h == nil || h.Open == nil { - return nil - } - return h.Begin(ctx, conn) -} - -func (h *Hooks) postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error { - if h == nil || h.PostBegin == nil { - return nil - } - return h.PostBegin(ctx, conn) -} - -func (h *Hooks) preCommit(c context.Context, tx *Tx) (interface{}, error) { - if h == nil || h.PreCommit == nil { - return nil, nil - } - return h.PreCommit(tx) -} - -func (h *Hooks) commit(c context.Context, ctx interface{}, tx *Tx) error { - if h == nil || h.Commit == nil { - return nil - } - return h.Commit(ctx, tx) -} - -func (h *Hooks) postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error { - if h == nil || h.PostCommit == nil { - return nil - } - return h.PostCommit(ctx, tx) -} - -func (h *Hooks) preRollback(c context.Context, tx *Tx) (interface{}, error) { - if h == nil || h.PreRollback == nil { - return nil, nil - } - return h.PreRollback(tx) -} - -func (h *Hooks) rollback(c context.Context, ctx interface{}, tx *Tx) error { - if h == nil || h.Rollback == nil { - return nil - } - return h.Rollback(ctx, tx) -} - -func (h *Hooks) postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error { - if h == nil || h.PostRollback == nil { - return nil - } - return h.PostRollback(ctx, tx) -} - -func (h *Hooks) preClose(c context.Context, conn *Conn) (interface{}, error) { - if h == nil || h.PreClose == nil { - return nil, nil - } - return h.PreClose(conn) -} - -func (h *Hooks) close(c context.Context, ctx interface{}, conn *Conn) error { - if h == nil || h.Close == nil { - return nil - } - return h.Close(ctx, conn) -} - -func (h *Hooks) postClose(c context.Context, ctx interface{}, conn *Conn, err error) error { - if h == nil || h.PostClose == nil { - return nil - } - return h.PostClose(ctx, conn, err) -} - -func (h *Hooks) preResetSession(c context.Context, conn *Conn) (interface{}, error) { - if h == nil || h.PreResetSession == nil { - return nil, nil - } - return h.PreResetSession(conn) -} - -func (h *Hooks) resetSession(c context.Context, ctx interface{}, conn *Conn) error { - if h == nil || h.ResetSession == nil { - return nil - } - return h.ResetSession(ctx, conn) -} - -func (h *Hooks) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error { - if h == nil || h.PostResetSession == nil { - return nil - } - return h.PostResetSession(ctx, conn, err) -} - -type multipleHooks []hooks - -func (h multipleHooks) preDo(f func(h hooks) (interface{}, error)) (interface{}, error) { - if h == nil { - return nil, nil - } - ctx := make([]interface{}, len(h)) - var err error - for i, hk := range h { - ctx0, err0 := f(hk) - ctx[i] = ctx0 - if err0 != nil && err == nil { - err = err0 - } - } - return ctx, err -} - -func (h multipleHooks) do(ctx interface{}, f func(h hooks, ctx interface{}) error) error { - if h == nil { - return nil - } - sctx, ok := ctx.([]interface{}) - if !ok { - return errors.New("invalid context type") - } - for i, hk := range h { - if err := f(hk, sctx[i]); err != nil { - return err - } - } - return nil -} - -func (h multipleHooks) postDo(ctx interface{}, err error, f func(h hooks, ctx interface{}, err error) error) error { - if h == nil { - return nil - } - sctx, ok := ctx.([]interface{}) - if !ok { - return errors.New("invalid context type") - } - var reterr error - for i := len(h) - 1; i >= 0; i-- { - if err0 := f(h[i], sctx[i], err); err0 != nil { - if err == nil { - err = err0 - } - if reterr == nil { - reterr = err0 - } - } - } - return reterr -} - -func (h multipleHooks) prePing(c context.Context, conn *Conn) (interface{}, error) { - return h.preDo(func(h hooks) (interface{}, error) { - return h.prePing(c, conn) - }) -} - -func (h multipleHooks) ping(c context.Context, ctx interface{}, conn *Conn) error { - return h.do(ctx, func(h hooks, ctx interface{}) error { - return h.ping(c, ctx, conn) - }) -} - -func (h multipleHooks) postPing(c context.Context, ctx interface{}, conn *Conn, err error) error { - return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { - return h.postPing(c, ctx, conn, err) - }) -} - -func (h multipleHooks) preOpen(c context.Context, name string) (interface{}, error) { - return h.preDo(func(h hooks) (interface{}, error) { - return h.preOpen(c, name) - }) -} - -func (h multipleHooks) open(c context.Context, ctx interface{}, conn *Conn) error { - return h.do(ctx, func(h hooks, ctx interface{}) error { - return h.open(c, ctx, conn) - }) -} - -func (h multipleHooks) postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error { - return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { - return h.postOpen(c, ctx, conn, err) - }) -} - -func (h multipleHooks) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { - return h.preDo(func(h hooks) (interface{}, error) { - return h.preExec(c, stmt, args) - }) -} - -func (h multipleHooks) exec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error { - return h.do(ctx, func(h hooks, ctx interface{}) error { - return h.exec(c, ctx, stmt, args, result) - }) -} - -func (h multipleHooks) postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error { - return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { - return h.postExec(c, ctx, stmt, args, result, err) - }) -} - -func (h multipleHooks) preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { - return h.preDo(func(h hooks) (interface{}, error) { - return h.preQuery(c, stmt, args) - }) -} - -func (h multipleHooks) query(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error { - return h.do(ctx, func(h hooks, ctx interface{}) error { - return h.query(c, ctx, stmt, args, rows) - }) -} - -func (h multipleHooks) postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error { - return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { - return h.postQuery(c, ctx, stmt, args, rows, err) - }) -} - -func (h multipleHooks) preBegin(c context.Context, conn *Conn) (interface{}, error) { - return h.preDo(func(h hooks) (interface{}, error) { - return h.preBegin(c, conn) - }) -} - -func (h multipleHooks) begin(c context.Context, ctx interface{}, conn *Conn) error { - return h.do(ctx, func(h hooks, ctx interface{}) error { - return h.begin(c, ctx, conn) - }) -} - -func (h multipleHooks) postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error { - return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { - return h.postBegin(c, ctx, conn, err) - }) -} - -func (h multipleHooks) preCommit(c context.Context, tx *Tx) (interface{}, error) { - return h.preDo(func(h hooks) (interface{}, error) { - return h.preCommit(c, tx) - }) -} - -func (h multipleHooks) commit(c context.Context, ctx interface{}, tx *Tx) error { - return h.do(ctx, func(h hooks, ctx interface{}) error { - return h.commit(c, ctx, tx) - }) -} - -func (h multipleHooks) postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error { - return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { - return h.postCommit(c, ctx, tx, err) - }) -} - -func (h multipleHooks) preRollback(c context.Context, tx *Tx) (interface{}, error) { - return h.preDo(func(h hooks) (interface{}, error) { - return h.preRollback(c, tx) - }) -} - -func (h multipleHooks) rollback(c context.Context, ctx interface{}, tx *Tx) error { - return h.do(ctx, func(h hooks, ctx interface{}) error { - return h.rollback(c, ctx, tx) - }) -} - -func (h multipleHooks) postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error { - return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { - return h.postRollback(c, ctx, tx, err) - }) -} - -func (h multipleHooks) preClose(c context.Context, conn *Conn) (interface{}, error) { - return h.preDo(func(h hooks) (interface{}, error) { - return h.preClose(c, conn) - }) -} - -func (h multipleHooks) close(c context.Context, ctx interface{}, conn *Conn) error { - return h.do(ctx, func(h hooks, ctx interface{}) error { - return h.close(c, ctx, conn) - }) -} - -func (h multipleHooks) postClose(c context.Context, ctx interface{}, conn *Conn, err error) error { - return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { - return h.postClose(c, ctx, conn, err) - }) -} - -func (h multipleHooks) preResetSession(c context.Context, conn *Conn) (interface{}, error) { - return h.preDo(func(h hooks) (interface{}, error) { - return h.preResetSession(c, conn) - }) -} - -func (h multipleHooks) resetSession(c context.Context, ctx interface{}, conn *Conn) error { - return h.do(ctx, func(h hooks, ctx interface{}) error { - return h.resetSession(c, ctx, conn) - }) -} - -func (h multipleHooks) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error { - return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { - return h.postResetSession(c, ctx, conn, err) - }) -} - // NewProxy creates new Proxy driver. // DEPRECATED: You should use NewProxyContext instead. func NewProxy(driver driver.Driver, hs ...*Hooks) *Proxy { @@ -1306,24 +67,6 @@ func NewProxyContext(driver driver.Driver, hs ...*HooksContext) *Proxy { } } -type contextHooksKey struct{} - -// WithHooks returns a copy of parent context in which the hooks associated. -func WithHooks(ctx context.Context, hs ...*HooksContext) context.Context { - switch len(hs) { - case 0: - return context.WithValue(ctx, contextHooksKey{}, (*HooksContext)(nil)) - case 1: - return context.WithValue(ctx, contextHooksKey{}, hs[0]) - } - - hooksSlice := make([]hooks, len(hs)) - for i, hk := range hs { - hooksSlice[i] = hk - } - return context.WithValue(ctx, contextHooksKey{}, multipleHooks(hooksSlice)) -} - func (p *Proxy) getHooks(ctx context.Context) hooks { if h, ok := ctx.Value(contextHooksKey{}).(hooks); ok { // Make the caller nil check easy. diff --git a/proxy_test.go b/proxy_test.go index 0303824..86e5e48 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -6,340 +6,12 @@ import ( "bytes" "context" "database/sql" - "database/sql/driver" "encoding/json" "errors" "fmt" "testing" ) -func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { - c := context.Background() - if ctx2, err := h.preOpen(c, ""); ctx2 != ctx || err != nil { - t.Errorf("preOpen returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) - } - if err := h.open(c, ctx, nil); err != nil { - t.Error("open returns error: ", err) - } - if err := h.postOpen(c, ctx, nil, nil); err != nil { - t.Error("postOpen returns error: ", err) - } - if ctx2, err := h.prePing(c, nil); ctx2 != ctx || err != nil { - t.Errorf("prePing returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) - } - if err := h.ping(c, ctx, nil); err != nil { - t.Error("ping returns error: ", err) - } - if err := h.postPing(c, ctx, nil, nil); err != nil { - t.Error("postPing returns error: ", err) - } - if ctx2, err := h.preExec(c, nil, nil); ctx2 != ctx || err != nil { - t.Errorf("preExec returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) - } - if err := h.exec(c, ctx, nil, nil, nil); err != nil { - t.Error("exec returns error: ", err) - } - if err := h.postExec(c, ctx, nil, nil, nil, nil); err != nil { - t.Error("postExec returns error: ", err) - } - if ctx2, err := h.preQuery(c, nil, nil); ctx2 != ctx || err != nil { - t.Errorf("preQuery returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) - } - if err := h.query(c, ctx, nil, nil, nil); err != nil { - t.Error("query returns error: ", err) - } - if err := h.postQuery(c, ctx, nil, nil, nil, nil); err != nil { - t.Error("postQuery returns error: ", err) - } - if ctx2, err := h.preBegin(c, nil); ctx2 != ctx || err != nil { - t.Errorf("preBegin returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) - } - if err := h.begin(c, ctx, nil); err != nil { - t.Error("begin returns error: ", err) - } - if err := h.postBegin(c, ctx, nil, nil); err != nil { - t.Error("postBegin returns error: ", err) - } - if ctx2, err := h.preCommit(c, nil); ctx2 != ctx || err != nil { - t.Errorf("preCommit returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) - } - if err := h.commit(c, ctx, nil); err != nil { - t.Error("commit returns error: ", err) - } - if err := h.postCommit(c, ctx, nil, nil); err != nil { - t.Error("postCommit returns error: ", err) - } - if ctx2, err := h.preRollback(c, nil); ctx2 != ctx || err != nil { - t.Errorf("preRollback returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) - } - if err := h.rollback(c, ctx, nil); err != nil { - t.Error("rollback returns error: ", err) - } - if err := h.postRollback(c, ctx, nil, nil); err != nil { - t.Error("postRollback returns error: ", err) - } - if ctx2, err := h.preClose(c, nil); ctx2 != ctx || err != nil { - t.Errorf("preClose returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) - } - if err := h.close(c, ctx, nil); err != nil { - t.Error("close returns error: ", err) - } - if err := h.postClose(c, ctx, nil, nil); err != nil { - t.Error("postClose returns error: ", err) - } - if ctx2, err := h.preResetSession(c, nil); ctx2 != ctx || err != nil { - t.Errorf("preResetSession returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) - } - if err := h.resetSession(c, ctx, nil); err != nil { - t.Error("resetSession returns error: ", err) - } - if err := h.postResetSession(c, ctx, nil, nil); err != nil { - t.Error("postResetSession returns error: ", err) - } -} - -func TestNilHooksContext(t *testing.T) { - // nil HooksContext will not panic and have no effec - testHooksInterface(t, (*HooksContext)(nil), nil) -} - -func TestZeroHooksContext(t *testing.T) { - // zero HooksContext will not panic and have no effec - testHooksInterface(t, &HooksContext{}, nil) -} - -func TestHooksContext(t *testing.T) { - dummy := 0 - ctx0 := &dummy - checkCtx := func(name string, ctx interface{}) { - if ctx != ctx0 { - t.Errorf("unexpected ctx: got %v want %v in %s", ctx, ctx0, name) - } - } - testHooksInterface(t, &HooksContext{ - PrePing: func(c context.Context, conn *Conn) (interface{}, error) { - return ctx0, nil - }, - Ping: func(c context.Context, ctx interface{}, conn *Conn) error { - checkCtx("Ping", ctx) - return nil - }, - PostPing: func(c context.Context, ctx interface{}, conn *Conn, err error) error { - checkCtx("PostPing", ctx) - return err - }, - PreOpen: func(c context.Context, name string) (interface{}, error) { - return ctx0, nil - }, - Open: func(c context.Context, ctx interface{}, conn *Conn) error { - checkCtx("Open", ctx) - return nil - }, - PostOpen: func(c context.Context, ctx interface{}, conn *Conn, err error) error { - checkCtx("PostOpen", ctx) - return err - }, - PreExec: func(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { - return ctx0, nil - }, - Exec: func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error { - checkCtx("Exec", ctx) - return nil - }, - PostExec: func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error { - checkCtx("PostExec", ctx) - return err - }, - PreQuery: func(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { - return ctx0, nil - }, - Query: func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows) error { - checkCtx("Query", ctx) - return nil - }, - PostQuery: func(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error { - checkCtx("PostQuery", ctx) - return err - }, - PreBegin: func(c context.Context, conn *Conn) (interface{}, error) { - return ctx0, nil - }, - Begin: func(c context.Context, ctx interface{}, conn *Conn) error { - checkCtx("Begin", ctx) - return nil - }, - PostBegin: func(c context.Context, ctx interface{}, conn *Conn, err error) error { - checkCtx("PostBegin", ctx) - return err - }, - PreCommit: func(c context.Context, tx *Tx) (interface{}, error) { - return ctx0, nil - }, - Commit: func(c context.Context, ctx interface{}, tx *Tx) error { - checkCtx("Commit", ctx) - return nil - }, - PostCommit: func(c context.Context, ctx interface{}, tx *Tx, err error) error { - checkCtx("PostCommit", ctx) - return err - }, - PreRollback: func(c context.Context, tx *Tx) (interface{}, error) { - return ctx0, nil - }, - Rollback: func(c context.Context, ctx interface{}, tx *Tx) error { - checkCtx("Rollback", ctx) - return nil - }, - PostRollback: func(c context.Context, ctx interface{}, tx *Tx, err error) error { - checkCtx("PostRollback", ctx) - return err - }, - PreClose: func(c context.Context, conn *Conn) (interface{}, error) { - return ctx0, nil - }, - Close: func(c context.Context, ctx interface{}, conn *Conn) error { - checkCtx("Close", ctx) - return nil - }, - PostClose: func(c context.Context, ctx interface{}, conn *Conn, err error) error { - checkCtx("PostClose", ctx) - return err - }, - PreResetSession: func(c context.Context, conn *Conn) (interface{}, error) { - return ctx0, nil - }, - ResetSession: func(c context.Context, ctx interface{}, conn *Conn) error { - checkCtx("ResetSession", ctx) - return nil - }, - PostResetSession: func(c context.Context, ctx interface{}, conn *Conn, err error) error { - checkCtx("PostResetSession", ctx) - return err - }, - }, ctx0) -} - -func TestNilHooks(t *testing.T) { - // nil Hooks will not panic and have no effect - testHooksInterface(t, (*Hooks)(nil), nil) -} - -func TestZeroHooks(t *testing.T) { - // zero Hooks will not panic and have no effect - testHooksInterface(t, &Hooks{}, nil) -} - -func TestHooks(t *testing.T) { - dummy := 0 - ctx0 := &dummy - checkCtx := func(name string, ctx interface{}) { - if ctx != ctx0 { - t.Errorf("unexpected ctx: got %v want %v in %s", ctx, ctx0, name) - } - } - testHooksInterface(t, &Hooks{ - PrePing: func(conn *Conn) (interface{}, error) { - return ctx0, nil - }, - Ping: func(ctx interface{}, conn *Conn) error { - checkCtx("Ping", ctx) - return nil - }, - PostPing: func(ctx interface{}, conn *Conn, err error) error { - checkCtx("PostPing", ctx) - return err - }, - PreOpen: func(name string) (interface{}, error) { - return ctx0, nil - }, - Open: func(ctx interface{}, conn *Conn) error { - checkCtx("Open", ctx) - return nil - }, - PostOpen: func(ctx interface{}, conn *Conn) error { - checkCtx("PostOpen", ctx) - return nil - }, - PreExec: func(stmt *Stmt, args []driver.Value) (interface{}, error) { - return ctx0, nil - }, - Exec: func(ctx interface{}, stmt *Stmt, args []driver.Value, result driver.Result) error { - checkCtx("Exec", ctx) - return nil - }, - PostExec: func(ctx interface{}, stmt *Stmt, args []driver.Value, result driver.Result) error { - checkCtx("PostExec", ctx) - return nil - }, - PreQuery: func(stmt *Stmt, args []driver.Value) (interface{}, error) { - return ctx0, nil - }, - Query: func(ctx interface{}, stmt *Stmt, args []driver.Value, rows driver.Rows) error { - checkCtx("Query", ctx) - return nil - }, - PostQuery: func(ctx interface{}, stmt *Stmt, args []driver.Value, rows driver.Rows) error { - checkCtx("PostQuery", ctx) - return nil - }, - PreBegin: func(conn *Conn) (interface{}, error) { - return ctx0, nil - }, - Begin: func(ctx interface{}, conn *Conn) error { - checkCtx("Begin", ctx) - return nil - }, - PostBegin: func(ctx interface{}, conn *Conn) error { - checkCtx("PostBegin", ctx) - return nil - }, - PreCommit: func(tx *Tx) (interface{}, error) { - return ctx0, nil - }, - Commit: func(ctx interface{}, tx *Tx) error { - checkCtx("Commit", ctx) - return nil - }, - PostCommit: func(ctx interface{}, tx *Tx) error { - checkCtx("PostCommit", ctx) - return nil - }, - PreRollback: func(tx *Tx) (interface{}, error) { - return ctx0, nil - }, - Rollback: func(ctx interface{}, tx *Tx) error { - checkCtx("Rollback", ctx) - return nil - }, - PostRollback: func(ctx interface{}, tx *Tx) error { - checkCtx("PostRollback", ctx) - return nil - }, - PreClose: func(conn *Conn) (interface{}, error) { - return ctx0, nil - }, - Close: func(ctx interface{}, conn *Conn) error { - checkCtx("Close", ctx) - return nil - }, - PostClose: func(ctx interface{}, conn *Conn, err error) error { - checkCtx("PostClose", ctx) - return err - }, - PreResetSession: func(conn *Conn) (interface{}, error) { - return ctx0, nil - }, - ResetSession: func(ctx interface{}, conn *Conn) error { - checkCtx("ResetSession", ctx) - return nil - }, - PostResetSession: func(ctx interface{}, conn *Conn, err error) error { - checkCtx("PostResetSession", ctx) - return err - }, - }, ctx0) -} - func TestFakeDB(t *testing.T) { testName := t.Name() testCases := []struct { From 565043804241101bb2295a682e257eb8247eef41 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Sat, 4 Jan 2020 17:42:17 +0900 Subject: [PATCH 3/3] add tests for multipleHooks --- hooks.go | 6 ++--- hooks_test.go | 68 +++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/hooks.go b/hooks.go index fc13fe3..1d25038 100644 --- a/hooks.go +++ b/hooks.go @@ -1031,7 +1031,7 @@ func (h *Hooks) postResetSession(c context.Context, ctx interface{}, conn *Conn, type multipleHooks []hooks func (h multipleHooks) preDo(f func(h hooks) (interface{}, error)) (interface{}, error) { - if h == nil { + if len(h) == 0 { return nil, nil } ctx := make([]interface{}, len(h)) @@ -1047,7 +1047,7 @@ func (h multipleHooks) preDo(f func(h hooks) (interface{}, error)) (interface{}, } func (h multipleHooks) do(ctx interface{}, f func(h hooks, ctx interface{}) error) error { - if h == nil { + if len(h) == 0 { return nil } sctx, ok := ctx.([]interface{}) @@ -1063,7 +1063,7 @@ func (h multipleHooks) do(ctx interface{}, f func(h hooks, ctx interface{}) erro } func (h multipleHooks) postDo(ctx interface{}, err error, f func(h hooks, ctx interface{}, err error) error) error { - if h == nil { + if len(h) == 0 { return nil } sctx, ok := ctx.([]interface{}) diff --git a/hooks_test.go b/hooks_test.go index 6a1e122..420660c 100644 --- a/hooks_test.go +++ b/hooks_test.go @@ -3,12 +3,14 @@ package proxy import ( "context" "database/sql/driver" + "reflect" "testing" + "time" ) func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { c := context.Background() - if ctx2, err := h.preOpen(c, ""); ctx2 != ctx || err != nil { + if ctx2, err := h.preOpen(c, ""); !reflect.DeepEqual(ctx2, ctx) || err != nil { t.Errorf("preOpen returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) } if err := h.open(c, ctx, nil); err != nil { @@ -17,7 +19,7 @@ func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { if err := h.postOpen(c, ctx, nil, nil); err != nil { t.Error("postOpen returns error: ", err) } - if ctx2, err := h.prePing(c, nil); ctx2 != ctx || err != nil { + if ctx2, err := h.prePing(c, nil); !reflect.DeepEqual(ctx2, ctx) || err != nil { t.Errorf("prePing returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) } if err := h.ping(c, ctx, nil); err != nil { @@ -26,7 +28,7 @@ func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { if err := h.postPing(c, ctx, nil, nil); err != nil { t.Error("postPing returns error: ", err) } - if ctx2, err := h.preExec(c, nil, nil); ctx2 != ctx || err != nil { + if ctx2, err := h.preExec(c, nil, nil); !reflect.DeepEqual(ctx2, ctx) || err != nil { t.Errorf("preExec returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) } if err := h.exec(c, ctx, nil, nil, nil); err != nil { @@ -35,7 +37,7 @@ func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { if err := h.postExec(c, ctx, nil, nil, nil, nil); err != nil { t.Error("postExec returns error: ", err) } - if ctx2, err := h.preQuery(c, nil, nil); ctx2 != ctx || err != nil { + if ctx2, err := h.preQuery(c, nil, nil); !reflect.DeepEqual(ctx2, ctx) || err != nil { t.Errorf("preQuery returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) } if err := h.query(c, ctx, nil, nil, nil); err != nil { @@ -44,7 +46,7 @@ func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { if err := h.postQuery(c, ctx, nil, nil, nil, nil); err != nil { t.Error("postQuery returns error: ", err) } - if ctx2, err := h.preBegin(c, nil); ctx2 != ctx || err != nil { + if ctx2, err := h.preBegin(c, nil); !reflect.DeepEqual(ctx2, ctx) || err != nil { t.Errorf("preBegin returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) } if err := h.begin(c, ctx, nil); err != nil { @@ -53,7 +55,7 @@ func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { if err := h.postBegin(c, ctx, nil, nil); err != nil { t.Error("postBegin returns error: ", err) } - if ctx2, err := h.preCommit(c, nil); ctx2 != ctx || err != nil { + if ctx2, err := h.preCommit(c, nil); !reflect.DeepEqual(ctx2, ctx) || err != nil { t.Errorf("preCommit returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) } if err := h.commit(c, ctx, nil); err != nil { @@ -62,7 +64,7 @@ func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { if err := h.postCommit(c, ctx, nil, nil); err != nil { t.Error("postCommit returns error: ", err) } - if ctx2, err := h.preRollback(c, nil); ctx2 != ctx || err != nil { + if ctx2, err := h.preRollback(c, nil); !reflect.DeepEqual(ctx2, ctx) || err != nil { t.Errorf("preRollback returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) } if err := h.rollback(c, ctx, nil); err != nil { @@ -71,7 +73,7 @@ func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { if err := h.postRollback(c, ctx, nil, nil); err != nil { t.Error("postRollback returns error: ", err) } - if ctx2, err := h.preClose(c, nil); ctx2 != ctx || err != nil { + if ctx2, err := h.preClose(c, nil); !reflect.DeepEqual(ctx2, ctx) || err != nil { t.Errorf("preClose returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) } if err := h.close(c, ctx, nil); err != nil { @@ -80,7 +82,7 @@ func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { if err := h.postClose(c, ctx, nil, nil); err != nil { t.Error("postClose returns error: ", err) } - if ctx2, err := h.preResetSession(c, nil); ctx2 != ctx || err != nil { + if ctx2, err := h.preResetSession(c, nil); !reflect.DeepEqual(ctx2, ctx) || err != nil { t.Errorf("preResetSession returns unexpected values: got (%v, %v) want (%v, nil)", ctx2, err, ctx) } if err := h.resetSession(c, ctx, nil); err != nil { @@ -92,24 +94,24 @@ func testHooksInterface(t *testing.T, h hooks, ctx interface{}) { } func TestNilHooksContext(t *testing.T) { - // nil HooksContext will not panic and have no effec + // nil HooksContext will not panic and have no effect testHooksInterface(t, (*HooksContext)(nil), nil) } func TestZeroHooksContext(t *testing.T) { - // zero HooksContext will not panic and have no effec + // zero HooksContext will not panic and have no effect testHooksInterface(t, &HooksContext{}, nil) } -func TestHooksContext(t *testing.T) { - dummy := 0 +func newTestHooksContext(t *testing.T) (*HooksContext, interface{}) { + dummy := time.Now().UnixNano() ctx0 := &dummy checkCtx := func(name string, ctx interface{}) { if ctx != ctx0 { t.Errorf("unexpected ctx: got %v want %v in %s", ctx, ctx0, name) } } - testHooksInterface(t, &HooksContext{ + return &HooksContext{ PrePing: func(c context.Context, conn *Conn) (interface{}, error) { return ctx0, nil }, @@ -209,7 +211,12 @@ func TestHooksContext(t *testing.T) { checkCtx("PostResetSession", ctx) return err }, - }, ctx0) + }, ctx0 +} + +func TestHooksContext(t *testing.T) { + hooks, ctx0 := newTestHooksContext(t) + testHooksInterface(t, hooks, ctx0) } func TestNilHooks(t *testing.T) { @@ -222,15 +229,15 @@ func TestZeroHooks(t *testing.T) { testHooksInterface(t, &Hooks{}, nil) } -func TestHooks(t *testing.T) { - dummy := 0 +func newTestHooks(t *testing.T) (*Hooks, interface{}) { + dummy := time.Now().UnixNano() ctx0 := &dummy checkCtx := func(name string, ctx interface{}) { if ctx != ctx0 { t.Errorf("unexpected ctx: got %v want %v in %s", ctx, ctx0, name) } } - testHooksInterface(t, &Hooks{ + return &Hooks{ PrePing: func(conn *Conn) (interface{}, error) { return ctx0, nil }, @@ -330,5 +337,28 @@ func TestHooks(t *testing.T) { checkCtx("PostResetSession", ctx) return err }, - }, ctx0) + }, ctx0 +} + +func TestHooks(t *testing.T) { + hooks, ctx0 := newTestHooks(t) + testHooksInterface(t, hooks, ctx0) +} + +func TestNilMultipleHooks(t *testing.T) { + // nil HooksContext will not panic and have no effect + testHooksInterface(t, multipleHooks(nil), nil) +} + +func TestZeroMultipleHooks(t *testing.T) { + // zero HooksContext will not panic and have no effect + testHooksInterface(t, multipleHooks{}, nil) +} + +func TestMultipleHooks(t *testing.T) { + hooks1, ctx1 := newTestHooksContext(t) + hooks2, ctx2 := newTestHooks(t) + hooks := multipleHooks{hooks1, hooks2} + ctx0 := []interface{}{ctx1, ctx2} + testHooksInterface(t, hooks, ctx0) }