diff --git a/main.go b/main.go index 79f978d..3b17c1e 100644 --- a/main.go +++ b/main.go @@ -36,7 +36,7 @@ func main() { log.Fatal(err.Error()) } t := time.Now() - if err := fuzzer.Run(db, fields, f); err != nil { + if err := fuzzer.Run(fields, f); err != nil { log.Fatal(err.Error()) } log.Printf("Fuzzing %s table taken: %v \n", table, time.Since(t)) diff --git a/main_test.go b/main_test.go index 2e28b39..dcdcb49 100644 --- a/main_test.go +++ b/main_test.go @@ -23,12 +23,14 @@ func TestFuzz(t *testing.T) { } f.Table = "Persons" f.Parsed = true + f.Num = 10 + f.Workers = 2 gofakeit.Seed(0) driver := drivers.New(f.Driver) testable := drivers.NewTestable(f.Driver) db := connector.Connection(driver) - defer connector.Close(driver) + defer db.Close() if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil { t.Fatal(err) } @@ -39,7 +41,7 @@ func TestFuzz(t *testing.T) { if err != nil { t.Fatal(err.Error()) } - err = fuzzer.Run(db, fields, f) + err = fuzzer.Run(fields, f) if err != nil { t.Fatal(err) } @@ -78,12 +80,14 @@ func TestFuzzPostgres(t *testing.T) { } f.Table = "Persons" f.Parsed = true + f.Num = 10 + f.Workers = 2 gofakeit.Seed(0) driver := drivers.New(f.Driver) testable := drivers.NewTestable(f.Driver) db := connector.Connection(driver) - defer connector.Close(driver) + defer db.Close() if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", f.Table)); err != nil { t.Fatal(err) } @@ -94,7 +98,7 @@ func TestFuzzPostgres(t *testing.T) { if err != nil { t.Fatal(err.Error()) } - err = fuzzer.Run(db, fields, f) + err = fuzzer.Run(fields, f) if err != nil { t.Fatal(err) } diff --git a/pkg/action/action.go b/pkg/action/action.go index a207561..09298c6 100644 --- a/pkg/action/action.go +++ b/pkg/action/action.go @@ -31,13 +31,7 @@ func Insert(db *sql.DB, fields []drivers.FieldDescriptor, driver drivers.Driver, } query := driver.Insert(f, table) - ins, err := db.Prepare(query) - if err != nil { - log.Printf("invalid preparing query: %s\n", query) - return fmt.Errorf("error preparing query: %w", err) - } - - _, err = ins.Exec(values...) + _, err := db.Exec(query, values...) return err } diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index e8a309a..a2bd95c 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -3,42 +3,20 @@ package connector import ( "database/sql" "log" - "sync" "github.com/PumpkinSeed/sqlfuzz/drivers" _ "github.com/lib/pq" ) -var ( - driverDBMap = make(map[string]*sql.DB) - mu = sync.Mutex{} -) - // Connection building a singleton connection to the database for give driver func Connection(d drivers.Driver) *sql.DB { - mu.Lock() - defer mu.Unlock() - if db, ok := driverDBMap[d.Driver()]; ok { - return db - } db, err := connect(d) if err != nil { log.Fatal(err) return nil } - driverDBMap[d.Driver()] = db - return db -} -func Close(d drivers.Driver) error { - mu.Lock() - defer mu.Unlock() - db, ok := driverDBMap[d.Driver()] - if !ok { - return nil - } - delete(driverDBMap, d.Driver()) - return db.Close() + return db } // connect doing the direct connection open to the SQL database diff --git a/pkg/flags/flags.go b/pkg/flags/flags.go index bda929e..807df44 100644 --- a/pkg/flags/flags.go +++ b/pkg/flags/flags.go @@ -30,9 +30,9 @@ func Get() Flags { // parse parsing the flags into the f variable func parse() { if !f.Parsed { - flag.StringVar(&f.Driver.Username, "u", "", "Username for the database connection") - flag.StringVar(&f.Driver.Password, "p", "", "Password for the database connection") - flag.StringVar(&f.Driver.Database, "d", "", "Database of the database connection") + flag.StringVar(&f.Driver.Username, "u", "test", "Username for the database connection") + flag.StringVar(&f.Driver.Password, "p", "test", "Password for the database connection") + flag.StringVar(&f.Driver.Database, "d", "test", "Database of the database connection") flag.StringVar(&f.Driver.Host, "h", "localhost", "Host for the database connection") flag.StringVar(&f.Driver.Port, "P", "3306", "Port for the database connection") flag.StringVar(&f.Driver.Driver, "D", "mysql", "Driver for the database connection (mysql, postgres, etc.)") diff --git a/pkg/fuzzer/runner.go b/pkg/fuzzer/runner.go index 035c380..ac11872 100644 --- a/pkg/fuzzer/runner.go +++ b/pkg/fuzzer/runner.go @@ -1,25 +1,25 @@ package fuzzer import ( - "database/sql" "log" "sync" "github.com/PumpkinSeed/sqlfuzz/drivers" "github.com/PumpkinSeed/sqlfuzz/pkg/action" + "github.com/PumpkinSeed/sqlfuzz/pkg/connector" "github.com/PumpkinSeed/sqlfuzz/pkg/flags" _ "github.com/lib/pq" ) // Run the commands in a worker pool -func Run(db *sql.DB, fields []drivers.FieldDescriptor, f flags.Flags) error { +func Run(fields []drivers.FieldDescriptor, f flags.Flags) error { numJobs := f.Num workers := f.Workers jobs := make(chan struct{}, numJobs) wg := &sync.WaitGroup{} wg.Add(workers) for w := 0; w < workers; w++ { - go worker(db, jobs, fields, wg, f) + go worker(jobs, fields, wg, f) } for j := 0; j < numJobs; j++ { @@ -28,13 +28,21 @@ func Run(db *sql.DB, fields []drivers.FieldDescriptor, f flags.Flags) error { close(jobs) wg.Wait() - return action.Insert(db, fields, drivers.New(f.Driver), f.Table) + return nil } // worker of the worker pool, executing the command, logging if fails -func worker(db *sql.DB, jobs <-chan struct{}, fields []drivers.FieldDescriptor, wg *sync.WaitGroup, f flags.Flags) { +func worker(jobs <-chan struct{}, fields []drivers.FieldDescriptor, wg *sync.WaitGroup, f flags.Flags) { + driver := drivers.New(f.Driver) + db := connector.Connection(driver) + defer func() { + if err := db.Close(); err != nil { + log.Print(err) + } + }() + for range jobs { - if err := action.Insert(db, fields, drivers.New(f.Driver), f.Table); err != nil { + if err := action.Insert(db, fields, driver, f.Table); err != nil { log.Println(err) } }