Skip to content

Commit

Permalink
Merge pull request #84 from cdleo/main
Browse files Browse the repository at this point in the history
Adding capabilities to change the error returned by the driver
  • Loading branch information
shogo82148 authored Jul 22, 2023
2 parents 7cdac3f + 3c089fb commit 300c00d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 56 deletions.
61 changes: 25 additions & 36 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@ type Conn struct {
// It will trigger PrePing, Ping, PostPing hooks.
//
// If the original connection does not satisfy "database/sql/driver".Pinger, it does nothing.
func (conn *Conn) Ping(c context.Context) error {
var err error
func (conn *Conn) Ping(c context.Context) (err error) {
var ctx interface{}
hooks := conn.Proxy.getHooks(c)

if hooks != nil {
defer func() { hooks.postPing(c, ctx, conn, err) }()
defer func() { err = hooks.postPing(c, ctx, conn, err) }()
if ctx, err = hooks.prePing(c, conn); err != nil {
return err
}
Expand All @@ -49,31 +48,30 @@ func (conn *Conn) Prepare(query string) (driver.Stmt, error) {
}

// PrepareContext returns a prepared statement which is wrapped by Stmt.
func (conn *Conn) PrepareContext(c context.Context, query string) (driver.Stmt, error) {
func (conn *Conn) PrepareContext(c context.Context, query string) (stmt driver.Stmt, err error) {
var ctx interface{}
var stmt = &Stmt{
var stmtAux = &Stmt{
QueryString: query,
Proxy: conn.Proxy,
Conn: conn,
}
var err error
hooks := conn.Proxy.getHooks(c)
if hooks != nil {
defer func() { hooks.postPrepare(c, ctx, stmt, err) }()
if ctx, err = hooks.prePrepare(c, stmt); err != nil {
defer func() { err = hooks.postPrepare(c, ctx, stmtAux, err) }()
if ctx, err = hooks.prePrepare(c, stmtAux); err != nil {
return nil, err
}
}

if connCtx, ok := conn.Conn.(driver.ConnPrepareContext); ok {
stmt.Stmt, err = connCtx.PrepareContext(c, stmt.QueryString)
stmtAux.Stmt, err = connCtx.PrepareContext(c, stmtAux.QueryString)
} else {
stmt.Stmt, err = conn.Conn.Prepare(stmt.QueryString)
stmtAux.Stmt, err = conn.Conn.Prepare(stmtAux.QueryString)
if err == nil {
select {
default:
case <-c.Done():
stmt.Stmt.Close()
stmtAux.Stmt.Close()
return nil, c.Err()
}
}
Expand All @@ -83,21 +81,20 @@ func (conn *Conn) PrepareContext(c context.Context, query string) (driver.Stmt,
}

if hooks != nil {
if err = hooks.prepare(c, ctx, stmt); err != nil {
if err = hooks.prepare(c, ctx, stmtAux); err != nil {
return nil, err
}
}
return stmt, nil
return stmtAux, nil
}

// Close calls the original Close method.
func (conn *Conn) Close() error {
func (conn *Conn) Close() (err error) {
ctx := context.Background()
var err error
var myctx interface{}

if hooks := conn.Proxy.hooks; hooks != nil {
defer func() { hooks.postClose(ctx, myctx, conn, err) }()
defer func() { err = hooks.postClose(ctx, myctx, conn, err) }()
if myctx, err = hooks.preClose(ctx, conn); err != nil {
return err
}
Expand All @@ -123,14 +120,12 @@ func (conn *Conn) Begin() (driver.Tx, error) {

// BeginTx starts and returns a new transaction which is wrapped by Tx.
// It will trigger PreBegin, Begin, PostBegin hooks.
func (conn *Conn) BeginTx(c context.Context, opts driver.TxOptions) (driver.Tx, error) {
func (conn *Conn) BeginTx(c context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
// set the hooks.
var err error
var ctx interface{}
var tx driver.Tx
hooks := conn.Proxy.getHooks(c)
if hooks != nil {
defer func() { hooks.postBegin(c, ctx, conn, err) }()
defer func() { err = hooks.postBegin(c, ctx, conn, err) }()
if ctx, err = hooks.preBegin(c, conn); err != nil {
return nil, err
}
Expand Down Expand Up @@ -193,7 +188,7 @@ func (conn *Conn) Exec(query string, args []driver.Value) (driver.Result, error)
// It will trigger PreExec, Exec, PostExec hooks.
//
// If the original connection does not satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error.
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (drv driver.Result, err error) {
execer, exOk := conn.Conn.(driver.Execer)
execerCtx, exCtxOk := conn.Conn.(driver.ExecerContext)
if !exOk && !exCtxOk {
Expand All @@ -207,19 +202,17 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
Conn: conn,
}
var ctx interface{}
var err error
var result driver.Result
hooks := conn.Proxy.getHooks(c)
if hooks != nil {
defer func() { hooks.postExec(c, ctx, stmt, args, result, err) }()
defer func() { err = hooks.postExec(c, ctx, stmt, args, drv, err) }()
if ctx, err = hooks.preExec(c, stmt, args); err != nil {
return nil, err
}
}

// call the original method.
if execerCtx != nil {
result, err = execerCtx.ExecContext(c, stmt.QueryString, args)
drv, err = execerCtx.ExecContext(c, stmt.QueryString, args)
} else {
select {
default:
Expand All @@ -230,19 +223,18 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
if err0 != nil {
return nil, err0
}
result, err = execer.Exec(stmt.QueryString, dargs)
drv, err = execer.Exec(stmt.QueryString, dargs)
}
if err != nil {
return nil, err
}

if hooks != nil {
if err = hooks.exec(c, ctx, stmt, args, result); err != nil {
if err = hooks.exec(c, ctx, stmt, args, drv); err != nil {
return nil, err
}
}

return result, nil
return drv, err
}

// Query executes a query that may return rows.
Expand All @@ -258,7 +250,7 @@ func (conn *Conn) Query(query string, args []driver.Value) (driver.Rows, error)
// It wil trigger PreQuery, Query, PostQuery hooks.
//
// If the original connection does not satisfy "database/sql/driver".QueryerContext nor "database/sql/driver".Queryer, it return ErrSkip error.
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
queryer, qok := conn.Conn.(driver.Queryer)
queryerCtx, qCtxOk := conn.Conn.(driver.QueryerContext)
if !qok && !qCtxOk {
Expand All @@ -271,11 +263,9 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na
Conn: conn,
}
var ctx interface{}
var err error
var rows driver.Rows
hooks := conn.Proxy.getHooks(c)
if hooks != nil {
defer func() { hooks.postQuery(c, ctx, stmt, args, rows, err) }()
defer func() { err = hooks.postQuery(c, ctx, stmt, args, rows, err) }()
if ctx, err = hooks.preQuery(c, stmt, args); err != nil {
return nil, err
}
Expand Down Expand Up @@ -343,13 +333,12 @@ type sessionResetter interface {
}

// ResetSession resets the state of Conn.
func (conn *Conn) ResetSession(ctx context.Context) error {
var err error
func (conn *Conn) ResetSession(ctx context.Context) (err error) {
var myctx interface{}
hooks := conn.Proxy.getHooks(ctx)

if hooks != nil {
defer func() { hooks.postResetSession(ctx, myctx, conn, err) }()
defer func() { err = hooks.postResetSession(ctx, myctx, conn, err) }()
if myctx, err = hooks.preResetSession(ctx, conn); err != nil {
return err
}
Expand Down
20 changes: 10 additions & 10 deletions hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ func (h *HooksContext) ping(c context.Context, ctx interface{}, conn *Conn) erro

func (h *HooksContext) postPing(c context.Context, ctx interface{}, conn *Conn, err error) error {
if h == nil || h.PostPing == nil {
return nil
return err
}
return h.PostPing(c, ctx, conn, err)
}
Expand All @@ -433,7 +433,7 @@ func (h *HooksContext) open(c context.Context, ctx interface{}, conn *Conn) erro

func (h *HooksContext) postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error {
if h == nil || h.PostOpen == nil {
return nil
return err
}
return h.PostOpen(c, ctx, conn, err)
}
Expand All @@ -454,7 +454,7 @@ func (h *HooksContext) prepare(c context.Context, ctx interface{}, stmt *Stmt) e

func (h *HooksContext) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
if h == nil || h.PostPrepare == nil {
return nil
return err
}
return h.PostPrepare(c, ctx, stmt, err)
}
Expand All @@ -475,7 +475,7 @@ func (h *HooksContext) exec(c context.Context, ctx interface{}, stmt *Stmt, args

func (h *HooksContext) postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error {
if h == nil || h.PostExec == nil {
return nil
return err
}
return h.PostExec(c, ctx, stmt, args, result, err)
}
Expand All @@ -496,7 +496,7 @@ func (h *HooksContext) query(c context.Context, ctx interface{}, stmt *Stmt, arg

func (h *HooksContext) postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error {
if h == nil || h.PostQuery == nil {
return nil
return err
}
return h.PostQuery(c, ctx, stmt, args, rows, err)
}
Expand All @@ -517,7 +517,7 @@ func (h *HooksContext) begin(c context.Context, ctx interface{}, conn *Conn) err

func (h *HooksContext) postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error {
if h == nil || h.PostBegin == nil {
return nil
return err
}
return h.PostBegin(c, ctx, conn, err)
}
Expand All @@ -538,7 +538,7 @@ func (h *HooksContext) commit(c context.Context, ctx interface{}, tx *Tx) error

func (h *HooksContext) postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error {
if h == nil || h.PostCommit == nil {
return nil
return err
}
return h.PostCommit(c, ctx, tx, err)
}
Expand All @@ -559,7 +559,7 @@ func (h *HooksContext) rollback(c context.Context, ctx interface{}, tx *Tx) erro

func (h *HooksContext) postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error {
if h == nil || h.PostRollback == nil {
return nil
return err
}
return h.PostRollback(c, ctx, tx, err)
}
Expand All @@ -580,7 +580,7 @@ func (h *HooksContext) close(c context.Context, ctx interface{}, conn *Conn) err

func (h *HooksContext) postClose(c context.Context, ctx interface{}, conn *Conn, err error) error {
if h == nil || h.PostClose == nil {
return nil
return err
}
return h.PostClose(c, ctx, conn, err)
}
Expand All @@ -601,7 +601,7 @@ func (h *HooksContext) resetSession(c context.Context, ctx interface{}, conn *Co

func (h *HooksContext) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error {
if h == nil || h.PostResetSession == nil {
return nil
return err
}
return h.PostResetSession(c, ctx, conn, err)
}
Expand Down
20 changes: 10 additions & 10 deletions logging_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (h *loggingHook) postPing(c context.Context, ctx interface{}, conn *Conn, e
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostPing]")
return nil
return err
}

func (h *loggingHook) preOpen(c context.Context, name string) (interface{}, error) {
Expand All @@ -58,7 +58,7 @@ func (h *loggingHook) postOpen(c context.Context, ctx interface{}, conn *Conn, e
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostOpen]")
return nil
return err
}

func (h *loggingHook) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
Expand All @@ -79,7 +79,7 @@ func (h *loggingHook) postPrepare(c context.Context, ctx interface{}, stmt *Stmt
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostPrepare]")
return nil
return err
}

func (h *loggingHook) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
Expand All @@ -100,7 +100,7 @@ func (h *loggingHook) postExec(c context.Context, ctx interface{}, stmt *Stmt, a
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostExec]")
return nil
return err
}

func (h *loggingHook) preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
Expand All @@ -121,7 +121,7 @@ func (h *loggingHook) postQuery(c context.Context, ctx interface{}, stmt *Stmt,
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostQuery]")
return nil
return err
}

func (h *loggingHook) preBegin(c context.Context, conn *Conn) (interface{}, error) {
Expand All @@ -142,7 +142,7 @@ func (h *loggingHook) postBegin(c context.Context, ctx interface{}, conn *Conn,
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostBegin]")
return nil
return err
}

func (h *loggingHook) preCommit(c context.Context, tx *Tx) (interface{}, error) {
Expand All @@ -163,7 +163,7 @@ func (h *loggingHook) postCommit(c context.Context, ctx interface{}, tx *Tx, err
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostCommit]")
return nil
return err
}

func (h *loggingHook) preRollback(c context.Context, tx *Tx) (interface{}, error) {
Expand All @@ -184,7 +184,7 @@ func (h *loggingHook) postRollback(c context.Context, ctx interface{}, tx *Tx, e
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostRollback]")
return nil
return err
}

func (h *loggingHook) preClose(c context.Context, conn *Conn) (interface{}, error) {
Expand All @@ -205,7 +205,7 @@ func (h *loggingHook) postClose(c context.Context, ctx interface{}, conn *Conn,
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostClose]")
return nil
return err
}

func (h *loggingHook) preResetSession(c context.Context, conn *Conn) (interface{}, error) {
Expand All @@ -217,7 +217,7 @@ func (h *loggingHook) resetSession(c context.Context, ctx interface{}, conn *Con
}

func (h *loggingHook) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error {
return nil
return err
}

func (h *loggingHook) preIsValid(conn *Conn) (interface{}, error) {
Expand Down

0 comments on commit 300c00d

Please sign in to comment.