Skip to content

Commit

Permalink
Merge pull request #22 from PumpkinSeed/hotfix-load-fail
Browse files Browse the repository at this point in the history
Fix database connection after higher load
  • Loading branch information
PumpkinSeed authored Mar 23, 2021
2 parents cec1fdc + 6d62ec7 commit a6449b8
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 44 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func main() {
log.Fatal(err.Error())
}
t := time.Now()
if err := fuzzer.Run(db, fields, f); err != nil {
if err := fuzzer.Run(fields, f); err != nil {
log.Fatal(err.Error())
}
log.Printf("Fuzzing %s table taken: %v \n", table, time.Since(t))
Expand Down
12 changes: 8 additions & 4 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ func TestFuzz(t *testing.T) {
}
f.Table = "Persons"
f.Parsed = true
f.Num = 10
f.Workers = 2

gofakeit.Seed(0)
driver := drivers.New(f.Driver)
testable := drivers.NewTestable(f.Driver)
db := connector.Connection(driver)
defer connector.Close(driver)
defer db.Close()
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil {
t.Fatal(err)
}
Expand All @@ -39,7 +41,7 @@ func TestFuzz(t *testing.T) {
if err != nil {
t.Fatal(err.Error())
}
err = fuzzer.Run(db, fields, f)
err = fuzzer.Run(fields, f)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -78,12 +80,14 @@ func TestFuzzPostgres(t *testing.T) {
}
f.Table = "Persons"
f.Parsed = true
f.Num = 10
f.Workers = 2

gofakeit.Seed(0)
driver := drivers.New(f.Driver)
testable := drivers.NewTestable(f.Driver)
db := connector.Connection(driver)
defer connector.Close(driver)
defer db.Close()
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil {
t.Fatal(err)
}
Expand All @@ -94,7 +98,7 @@ func TestFuzzPostgres(t *testing.T) {
if err != nil {
t.Fatal(err.Error())
}
err = fuzzer.Run(db, fields, f)
err = fuzzer.Run(fields, f)
if err != nil {
t.Fatal(err)
}
Expand Down
8 changes: 1 addition & 7 deletions pkg/action/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,7 @@ func Insert(db *sql.DB, fields []drivers.FieldDescriptor, driver drivers.Driver,
}
query := driver.Insert(f, table)

ins, err := db.Prepare(query)
if err != nil {
log.Printf("invalid preparing query: %s\n", query)
return fmt.Errorf("error preparing query: %w", err)
}

_, err = ins.Exec(values...)
_, err := db.Exec(query, values...)
return err
}

Expand Down
24 changes: 1 addition & 23 deletions pkg/connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,20 @@ package connector
import (
"database/sql"
"log"
"sync"

"github.com/PumpkinSeed/sqlfuzz/drivers"
_ "github.com/lib/pq"
)

var (
driverDBMap = make(map[string]*sql.DB)
mu = sync.Mutex{}
)

// Connection building a singleton connection to the database for give driver
func Connection(d drivers.Driver) *sql.DB {
mu.Lock()
defer mu.Unlock()
if db, ok := driverDBMap[d.Driver()]; ok {
return db
}
db, err := connect(d)
if err != nil {
log.Fatal(err)
return nil
}
driverDBMap[d.Driver()] = db
return db
}

func Close(d drivers.Driver) error {
mu.Lock()
defer mu.Unlock()
db, ok := driverDBMap[d.Driver()]
if !ok {
return nil
}
delete(driverDBMap, d.Driver())
return db.Close()
return db
}

// connect doing the direct connection open to the SQL database
Expand Down
6 changes: 3 additions & 3 deletions pkg/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ func Get() Flags {
// parse parsing the flags into the f variable
func parse() {
if !f.Parsed {
flag.StringVar(&f.Driver.Username, "u", "", "Username for the database connection")
flag.StringVar(&f.Driver.Password, "p", "", "Password for the database connection")
flag.StringVar(&f.Driver.Database, "d", "", "Database of the database connection")
flag.StringVar(&f.Driver.Username, "u", "test", "Username for the database connection")
flag.StringVar(&f.Driver.Password, "p", "test", "Password for the database connection")
flag.StringVar(&f.Driver.Database, "d", "test", "Database of the database connection")
flag.StringVar(&f.Driver.Host, "h", "localhost", "Host for the database connection")
flag.StringVar(&f.Driver.Port, "P", "3306", "Port for the database connection")
flag.StringVar(&f.Driver.Driver, "D", "mysql", "Driver for the database connection (mysql, postgres, etc.)")
Expand Down
20 changes: 14 additions & 6 deletions pkg/fuzzer/runner.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
package fuzzer

import (
"database/sql"
"log"
"sync"

"github.com/PumpkinSeed/sqlfuzz/drivers"
"github.com/PumpkinSeed/sqlfuzz/pkg/action"
"github.com/PumpkinSeed/sqlfuzz/pkg/connector"
"github.com/PumpkinSeed/sqlfuzz/pkg/flags"
_ "github.com/lib/pq"
)

// Run the commands in a worker pool
func Run(db *sql.DB, fields []drivers.FieldDescriptor, f flags.Flags) error {
func Run(fields []drivers.FieldDescriptor, f flags.Flags) error {
numJobs := f.Num
workers := f.Workers
jobs := make(chan struct{}, numJobs)
wg := &sync.WaitGroup{}
wg.Add(workers)
for w := 0; w < workers; w++ {
go worker(db, jobs, fields, wg, f)
go worker(jobs, fields, wg, f)
}

for j := 0; j < numJobs; j++ {
Expand All @@ -28,13 +28,21 @@ func Run(db *sql.DB, fields []drivers.FieldDescriptor, f flags.Flags) error {
close(jobs)
wg.Wait()

return action.Insert(db, fields, drivers.New(f.Driver), f.Table)
return nil
}

// worker of the worker pool, executing the command, logging if fails
func worker(db *sql.DB, jobs <-chan struct{}, fields []drivers.FieldDescriptor, wg *sync.WaitGroup, f flags.Flags) {
func worker(jobs <-chan struct{}, fields []drivers.FieldDescriptor, wg *sync.WaitGroup, f flags.Flags) {
driver := drivers.New(f.Driver)
db := connector.Connection(driver)
defer func() {
if err := db.Close(); err != nil {
log.Print(err)
}
}()

for range jobs {
if err := action.Insert(db, fields, drivers.New(f.Driver), f.Table); err != nil {
if err := action.Insert(db, fields, driver, f.Table); err != nil {
log.Println(err)
}
}
Expand Down

0 comments on commit a6449b8

Please sign in to comment.