-
Notifications
You must be signed in to change notification settings - Fork 11
/
config_utils.py
123 lines (102 loc) · 4.66 KB
/
config_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Module to hold common config logic."""
import collections
import configparser
import logging
import psycopg2
import tenacity
LOGGER = logging.getLogger(__name__)

# Immutable bundle of everything needed to open a PostgreSQL connection.
# default_schema and the three ssl* fields may be None when the corresponding
# option is absent from the config (see get_database_connection_params_from_config).
DatabaseConnectionParams = collections.namedtuple(
    'DatabaseConnectionParams',
    'host database_name username password port '
    'default_schema sslrootcert sslcert sslkey')
def get_database_connection_params_from_config(config):
    """Build DatabaseConnectionParams from the POSTGRES section of a config.

    HOST, DBNAME, USER, PASSWORD and PORT are required: a missing section or
    option raises KeyError. SCHEMA, SERVER_CA, CLIENT_CERT and CLIENT_KEY are
    optional and resolve to None when absent.

    Args:
        config: configparser.ConfigParser holding a POSTGRES section.

    Returns:
        DatabaseConnectionParams to get a database connection.
    """
    section = config['POSTGRES']

    def optional(option):
        # Absent optional settings become None rather than raising.
        return config.get('POSTGRES', option, fallback=None)

    return DatabaseConnectionParams(
        host=section['HOST'],
        database_name=section['DBNAME'],
        username=section['USER'],
        password=section['PASSWORD'],
        port=section['PORT'],
        default_schema=optional('SCHEMA'),
        sslrootcert=optional('SERVER_CA'),
        sslcert=optional('CLIENT_CERT'),
        sslkey=optional('CLIENT_KEY'),
    )
@tenacity.retry(
    retry=tenacity.retry_if_exception_type(psycopg2.OperationalError),
    stop=tenacity.stop_after_attempt(4),
    wait=tenacity.wait_random_exponential(multiplier=1, max=120),
    before_sleep=tenacity.before_sleep_log(LOGGER, logging.INFO),
    reraise=True)
def _get_database_connection_with_retry(db_authorize):
    """Open a psycopg2 connection, retrying on psycopg2.OperationalError.

    At most 4 attempts are made in total. Between attempts tenacity sleeps a
    randomized exponential backoff capped at 120 seconds and logs the retry at
    INFO level on LOGGER. When attempts run out, the last OperationalError is
    re-raised as-is (reraise=True) instead of being wrapped in
    tenacity.RetryError.

    Args:
        db_authorize: libpq connection string passed to psycopg2.connect.

    Returns:
        psycopg2.connection.
    """
    return psycopg2.connect(db_authorize)
def _quote_conninfo_value(value):
    """Render value as a single-quoted libpq connection-string value.

    libpq requires empty values or values containing spaces to be quoted, with
    embedded backslashes and single quotes backslash-escaped. Quoting every
    value keeps passwords or file paths with such characters from corrupting
    the connection string.
    """
    escaped = str(value).replace('\\', '\\\\').replace("'", "\\'")
    return "'" + escaped + "'"


def _build_connection_string(database_connection_params):
    """Assemble the libpq keyword/value connection string for the params."""
    params = database_connection_params
    settings = [
        ('host', params.host),
        ('dbname', params.database_name),
        ('user', params.username),
        ('password', params.password),
        ('port', params.port),
    ]
    if params.default_schema:
        # Passed to the server as a command-line option setting search_path.
        settings.append(('options', '-csearch_path=%s' % params.default_schema))
    ssl_files = [
        ('sslrootcert', params.sslrootcert),
        ('sslcert', params.sslcert),
        ('sslkey', params.sslkey),
    ]
    if any(path for _, path in ssl_files):
        # Certificate files configured: use verify-ca instead of plain
        # "require". Only emit the files that are actually set; formatting an
        # unset one would hand libpq the literal path "None".
        settings.append(('sslmode', 'verify-ca'))
        settings.extend((key, path) for key, path in ssl_files if path)
    else:
        settings.append(('sslmode', 'require'))
    return ' '.join(
        '%s=%s' % (key, _quote_conninfo_value(value)) for key, value in settings)


def get_database_connection(database_connection_params, retry=True):
    """Get psycopg2 database connection using the provided params.

    Args:
        database_connection_params: DatabaseConnectionParams object from which
            to pull connection params.
        retry: If True, a connection attempt that fails with
            psycopg2.OperationalError is retried up to 3 additional times
            (4 attempts in total) via _get_database_connection_with_retry.
            If False, a single attempt is made.

    Returns:
        psycopg2.connection ready to be used.
    """
    db_authorize = _build_connection_string(database_connection_params)
    if retry:
        connection = _get_database_connection_with_retry(db_authorize)
    else:
        connection = psycopg2.connect(db_authorize)
    # NOTE(review): psycopg2 documents connection.dsn as obscuring the
    # password; confirm for the pinned psycopg2 version before relying on it.
    LOGGER.info('Established connection to %s', connection.dsn)
    return connection
def get_database_connection_from_config(config, retry=True):
    """Get psycopg2 database connection from the provided ConfigParser.

    Args:
        config: configparser.ConfigParser initialized from desired file; its
            POSTGRES section supplies the connection params.
        retry: Forwarded to get_database_connection. The default True keeps the
            previous behavior of retrying on psycopg2.OperationalError.

    Returns:
        psycopg2.connection ready to be used.
    """
    connection_params = get_database_connection_params_from_config(config)
    return get_database_connection(connection_params, retry=retry)
def get_facebook_access_token(config):
    """Return the TOKEN value from the FACEBOOK section of config.

    Args:
        config: configparser.ConfigParser (or any mapping of mappings) with a
            FACEBOOK section containing TOKEN.

    Returns:
        The configured token string.

    Raises:
        KeyError: If the FACEBOOK section or its TOKEN option is missing.
    """
    facebook_section = config['FACEBOOK']
    token = facebook_section['TOKEN']
    return token
def get_config(config_path):
    """Get configparser object initialized from config path.

    ConfigParser.read() silently ignores files it cannot open, so a wrong path
    would otherwise only surface later as a KeyError on a missing section. In
    that case a warning is logged. An (empty) ConfigParser is still returned,
    so existing callers keep working.

    Args:
        config_path: str file path to config.

    Returns:
        configparser.ConfigParser initialized from config_path.
    """
    config = configparser.ConfigParser()
    read_paths = config.read(config_path)
    if not read_paths:
        # Same logger object as the module-level LOGGER (getLogger caches
        # loggers by name).
        logging.getLogger(__name__).warning(
            'Config file %s could not be read; using empty config', config_path)
    return config
def configure_logger(log_filename):
    """Configure root logger to write to log_filename and to stderr.

    Does nothing if the root logger already has handlers. This mirrors
    logging.basicConfig(), which ignores new handlers in that case.

    Args:
        log_filename: str, filename to be used for log file.
    """
    if logging.getLogger().handlers:
        # basicConfig() would be a no-op here. Returning before building the
        # FileHandler avoids opening log_filename and leaking a file handle
        # that would never be attached or closed.
        return
    record_format = (
        '[%(levelname)s\t%(asctime)s] %(process)d %(thread)d {%(filename)s:%(lineno)d} '
        '%(message)s')
    logging.basicConfig(
        handlers=[logging.FileHandler(log_filename), logging.StreamHandler()],
        format=record_format,
        level=logging.INFO)