Skip to content

Commit

Permalink
Merge pull request #98 from rgooch/recover-rds-connections
Browse files Browse the repository at this point in the history
Recover from broken RDS connections.
  • Loading branch information
rgooch authored Mar 10, 2021
2 parents 492143b + e5106dd commit 08fba05
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
1 change: 1 addition & 0 deletions cmd/keymasterd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ type OpenIDConnectIDPConfig struct {

type ProfileStorageConfig struct {
AwsSecretId string `yaml:"aws_secret_id"`
ConnectionLifetime time.Duration `yaml:"connection_lifetime"`
StorageUrl string `yaml:"storage_url"`
SyncDelay time.Duration `yaml:"sync_delay"`
SyncInterval time.Duration `yaml:"sync_interval"`
Expand Down
27 changes: 20 additions & 7 deletions cmd/keymasterd/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ const (
profileDBFilename = "userProfiles.sqlite3"
cachedDBFilename = "cachedDB.sqlite3"

dbConnectionLifetimeDefault = time.Minute * 15
dbConnectionLifetimeMaximum = time.Hour

dbSyncDelayDefault = time.Second * 3
dbSyncDelayMinimum = time.Second
dbSyncDelayMaximum = time.Minute
Expand All @@ -47,6 +50,14 @@ func (config *ProfileStorageConfig) setSyncLimits() {
} else if config.SyncInterval > dbSyncIntervalMaximum {
config.SyncInterval = dbSyncIntervalMaximum
}
if config.ConnectionLifetime < 1 {
config.ConnectionLifetime = dbConnectionLifetimeDefault
}
if config.ConnectionLifetime < config.SyncInterval {
config.ConnectionLifetime = config.SyncInterval
} else if config.ConnectionLifetime > dbConnectionLifetimeMaximum {
config.ConnectionLifetime = dbConnectionLifetimeMaximum
}
}

func (state *RuntimeState) expandStorageUrl() error {
Expand Down Expand Up @@ -129,6 +140,8 @@ func initDBPostgres(state *RuntimeState) (err error) {
return err
}
}
// Ensure that broken connections are replaced.
state.db.SetConnMaxLifetime(state.Config.ProfileStorage.ConnectionLifetime)
return nil
}

Expand Down Expand Up @@ -382,7 +395,7 @@ func (state *RuntimeState) GetUsers() ([]string, bool, error) {
}
return dbMessage.Names, false, dbMessage.Err
case <-time.After(state.remoteDBQueryTimeout):
logger.Printf("GOT a timeout")
logger.Println("GetUsers: timed out on primary DB")
stmtText := getUsersStmt["sqlite"]
stmt, err := state.cacheDB.Prepare(stmtText)
if err != nil {
Expand All @@ -395,7 +408,7 @@ func (state *RuntimeState) GetUsers() ([]string, bool, error) {
if dbErr != nil {
logger.Printf("Problem with db = '%s'", err)
} else {
logger.Println("GOT data from db cache")
logger.Println("GetUsers: got data from DB cache")
}
return names, true, dbErr
}
Expand Down Expand Up @@ -463,7 +476,7 @@ func (state *RuntimeState) LoadUserProfile(username string) (
metricLogExternalServiceDuration("storage-read", time.Since(start))
profileBytes = dbMessage.ProfileBytes
case <-time.After(state.remoteDBQueryTimeout):
logger.Printf("GOT a timeout")
logger.Println("LoadUserProfile: timed out on primary DB")
fromCache = true
// load from cache
stmtText := loadUserProfileStmt["sqlite"]
Expand All @@ -484,7 +497,7 @@ func (state *RuntimeState) LoadUserProfile(username string) (
return nil, false, true, err
}
}
logger.Printf("GOT data from db cache")
logger.Println("LoadUserProfile: got data from DB cache")
}
logger.Debugf(10, "profile bytes len=%d", len(profileBytes))
//gobReader := bytes.NewReader(fileBytes)
Expand Down Expand Up @@ -643,7 +656,7 @@ func (state *RuntimeState) GetSigned(username string,
}
jwsData = dbMessage.JWSData
case <-time.After(state.remoteDBQueryTimeout):
logger.Printf("GOT a timeout")
logger.Println("GetSigned: timed out on primary DB")
// load from cache
stmtText := getSignedUserDataStmt["sqlite"]
stmt, err := state.cacheDB.Prepare(stmtText)
Expand All @@ -664,9 +677,9 @@ func (state *RuntimeState) GetSigned(username string,
return false, "", err
}
}
logger.Printf("GOT data from db cache")
logger.Println("GetSigned: got data from DB cache")
}
logger.Printf("GOT some jwsdata data")
logger.Println("GetSigned: got jwsdata")
storageJWT, err := state.getStorageDataFromStorageStringDataJWT(jwsData)
if err != nil {
logger.Debugf(2, "failed to get storage data %s data=%s", err, jwsData)
Expand Down

0 comments on commit 08fba05

Please sign in to comment.