Skip to content

Commit

Permalink
Merge pull request #39 from waleedkadous/feature/api-v2/refresh-tokens
Browse files Browse the repository at this point in the history
Introducing Refresh Tokens with Rotation for a Seamless User Experience
  • Loading branch information
waleedkadous authored Apr 28, 2024
2 parents 79fea69 + 4783593 commit 297ff61
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 70 deletions.
116 changes: 81 additions & 35 deletions ansari_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import os
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Dict

import bcrypt
Expand Down Expand Up @@ -53,15 +53,18 @@ def check_password(self, password, hashed):
# Check if the provided password matches the hash
return bcrypt.checkpw(password.encode(), hashed.encode(self.ENCODING))

def generate_token(self, user_id, token_type="login"):
"""Generate a new token for the user. There are two types of tokens:
- login: This is a token that is used to authenticate the user.
def generate_token(self, user_id, token_type="access", expiry_hours=1):
"""Generate a new token for the user. There are three types of tokens:
- access: This is a token that is used to authenticate the user.
- refresh: This is a token that is used to extend the user session when the access token expires.
- reset: This is a token that is used to reset the user's password.
"""
if token_type not in ["access", "reset", "refresh"]:
raise ValueError("Invalid token type")
payload = {
"user_id": user_id,
"type": token_type,
"exp": datetime.utcnow() + timedelta(days=1),
"exp": datetime.now(timezone.utc) + timedelta(hours=expiry_hours),
}
return jwt.encode(payload, self.token_secret_key, algorithm=self.ALGORITHM)

Expand All @@ -76,13 +79,18 @@ def validate_token(self, request: Request) -> Dict[str, str]:
)
logger.info(f"Payload is {payload}")
# Check that the token is in our database.
cur = self.conn.cursor()
select_cmd = (
"""SELECT user_id FROM user_tokens WHERE user_id = %s AND token = %s;"""
)
cur.execute(select_cmd, (payload["user_id"], token))
result = cur.fetchone()
cur.close()
if payload["type"] == "access":
db_table = "access_tokens"
elif payload["type"] == "refresh":
db_table = "refresh_tokens"
elif payload["type"] == "reset":
db_table = "reset_tokens"
with self.conn.cursor() as cur:
select_cmd = (
f"""SELECT user_id FROM {db_table} WHERE user_id = %s AND token = %s;"""
)
cur.execute(select_cmd, (payload["user_id"], token))
result = cur.fetchone()
if result is None:
logger.warning("Could not find token in database.")
raise HTTPException(
Expand All @@ -97,9 +105,6 @@ def validate_token(self, request: Request) -> Dict[str, str]:
raise HTTPException(
status_code=401, detail="Could not validate credentials"
)
finally:
if cur:
cur.close()

def validate_reset_token(self, token: str) -> Dict[str, str]:
try:
Expand Down Expand Up @@ -160,26 +165,36 @@ def account_exists(self, email):
if cur:
cur.close()

def save_token(self, user_id, token):
def save_access_token(self, user_id, token):
try:
cur = self.conn.cursor()
insert_cmd = "INSERT INTO user_tokens (user_id, token) " + \
"VALUES (%s, %s) ON CONFLICT (user_id) DO UPDATE SET token = %s"
cur.execute(insert_cmd, (user_id, token, token))
self.conn.commit()
return {"status": "success", "token": token}
with self.conn.cursor() as cur:
insert_cmd = f"INSERT INTO access_tokens (user_id, token) " + \
"VALUES (%s, %s) RETURNING id;"
cur.execute(insert_cmd, (user_id, token))
inserted_id = cur.fetchone()[0]
self.conn.commit()
return {"status": "success", "token": token, "token_db_id": inserted_id}
except Exception as e:
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}

def save_refresh_token(self, user_id, token, access_token_id):
try:
with self.conn.cursor() as cur:
insert_cmd = f"INSERT INTO refresh_tokens (user_id, token, access_token_id) " + \
"VALUES (%s, %s, %s);"
cur.execute(insert_cmd, (user_id, token, access_token_id))
self.conn.commit()
return {"status": "success", "token": token}
except Exception as e:
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
cur.close()

def save_reset_token(self, user_id, token):
try:
cur = self.conn.cursor()
insert_cmd = "INSERT INTO reset_tokens (user_id, token) " + \
"VALUES (%s, %s) ON CONFLICT (user_id) DO UPDATE SET token = %s"
"VALUES (%s, %s) ON CONFLICT (user_id) DO UPDATE SET token = %s;"
cur.execute(insert_cmd, (user_id, token, token))
self.conn.commit()
return {"status": "success", "token": token}
Expand Down Expand Up @@ -430,19 +445,50 @@ def delete_thread(self, thread_id, user_id):
if cur:
cur.close()

def logout(self, user_id):
def delete_access_refresh_tokens_pair(self, refresh_token):
try:
cur = self.conn.cursor()
delete_cmd = """DELETE FROM user_tokens WHERE user_id = %s;"""
cur.execute(delete_cmd, (user_id,))
self.conn.commit()
# get the associated access_token_id
# delete the access_token and refresh_token will be deleted automatically..
# ..because of the foreign key constraint DELETE CASCADE
with self.conn.cursor() as cur:
select_cmd = """SELECT access_token_id FROM refresh_tokens WHERE token = %s;"""
cur.execute(select_cmd, (refresh_token,))
result = cur.fetchone()
if result is None:
raise HTTPException(
status_code=401, detail="Incorrect refresh_token."
)
access_token_id = result[0]
delete_cmd = """DELETE FROM access_tokens WHERE id = %s;"""
cur.execute(delete_cmd, (access_token_id,))
self.conn.commit()
return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
raise HTTPException(status_code=500, detail="Internal server error.")

def delete_access_token(self, user_id, token):
try:
with self.conn.cursor() as cur:
delete_cmd = """DELETE FROM access_tokens WHERE user_id = %s AND token = %s;"""
cur.execute(delete_cmd, (user_id, token))
self.conn.commit()
return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}

def logout(self, user_id, token):
try:
with self.conn.cursor() as cur:
for db_table in ["access_tokens", "refresh_tokens"]:
delete_cmd = f"""DELETE FROM {db_table} WHERE user_id = %s AND token = %s;"""
cur.execute(delete_cmd, (user_id, token))
self.conn.commit()
return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
cur.close()

def set_pref(self, user_id, key, value):
cur = self.conn.cursor()
Expand All @@ -456,7 +502,7 @@ def set_pref(self, user_id, key, value):
def get_prefs(self, user_id):
cur = self.conn.cursor()
select_cmd = (
"""SELECT pref_key, pref_value FROM preferences WHERE user_id = %s"""
"""SELECT pref_key, pref_value FROM preferences WHERE user_id = %s;"""
)
cur.execute(select_cmd, (user_id,))
result = cur.fetchall()
Expand Down
58 changes: 36 additions & 22 deletions main_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
token_secret_key = os.getenv("SECRET_KEY", "secret")
ALGORITHM = "HS256"
ENCODING = "utf-8"
ACCESS_TOKEN_EXPIRY_HOURS = 2
REFRESH_TOKEN_EXPIRY_HOURS = 24*90
template_dir = "resources/templates"
# Register the UUID type globally
psycopg2.extras.register_uuid()
Expand All @@ -52,7 +54,6 @@
presenter = ApiPresenter(app, ansari)
presenter.present()


def validate_cors(request: Request) -> bool:
try:
logger.info(f"Raw request is {request.headers}")
Expand Down Expand Up @@ -119,11 +120,18 @@ async def login_user(req: LoginRequest, cors_ok: bool = Depends(validate_cors)):
if db.check_password(req.password, existing_hash):
# Generate a token and return it
try:
token = db.generate_token(user_id)
db.save_token(user_id, token)
access_token = db.generate_token(user_id, token_type="access", expiry_hours=ACCESS_TOKEN_EXPIRY_HOURS)
refresh_token = db.generate_token(user_id, token_type="refresh", expiry_hours=REFRESH_TOKEN_EXPIRY_HOURS)
access_token_insert_result = db.save_access_token(user_id, access_token)
if access_token_insert_result["status"] != "success":
raise HTTPException(status_code=500, detail="Couldn't save access token")
refresh_token_insert_result = db.save_refresh_token(user_id, refresh_token, access_token_insert_result["token_db_id"])
if refresh_token_insert_result["status"] != "success":
raise HTTPException(status_code=500, detail="Couldn't save refresh token")
return {
"status": "success",
"token": token,
"access_token": access_token,
"refresh_token": refresh_token,
"first_name": first_name,
"last_name": last_name,
}
Expand All @@ -136,21 +144,31 @@ async def login_user(req: LoginRequest, cors_ok: bool = Depends(validate_cors)):
raise HTTPException(status_code=403, detail="Invalid username or password")


@app.get("/api/v2/users/refresh_token")
@app.post("/api/v2/users/refresh_token")
async def refresh_token(
request: Request,
cors_ok: bool = Depends(validate_cors),
token_params: dict = Depends(db.validate_token),
):
"""Refreshes the token.
Returns a new token on success.
Returns 403 if the password is incorrect or the user doesn't exist.
"""Refreshes both the access token and the refresh token.
Returns the two new tokens on success.
Returns 403 if the refresh_token is invalid, has expired or the user doesn't exist.
"""
if cors_ok and token_params:
if token_params["type"] != "refresh":
raise HTTPException(status_code=403, detail="Invalid token type")
try:
token = db.generate_token(token_params["user_id"])
db.save_token(token_params["user_id"], token)
return {"status": "success", "token": token}
refresh_token = request.headers.get("Authorization", "").split(" ")[1]
db.delete_access_refresh_tokens_pair(refresh_token)
access_token = db.generate_token(token_params["user_id"], token_type="access", expiry_hours=ACCESS_TOKEN_EXPIRY_HOURS)
refresh_token = db.generate_token(token_params["user_id"], token_type="refresh", expiry_hours=REFRESH_TOKEN_EXPIRY_HOURS)
access_token_insert_result = db.save_access_token(token_params["user_id"], access_token)
if access_token_insert_result["status"] != "success":
raise HTTPException(status_code=500, detail="Couldn't save access token")
refresh_token_insert_result = db.save_refresh_token(token_params["user_id"], refresh_token, access_token_insert_result["token_db_id"])
if refresh_token_insert_result["status"] != "success":
raise HTTPException(status_code=500, detail="Couldn't save refresh token")
return {"status": "success", "access_token": access_token, "refresh_token": refresh_token}
except psycopg2.Error as e:
logger.critical(f"Error: {e}")
raise HTTPException(status_code=500, detail="Database error")
Expand All @@ -169,17 +187,14 @@ async def logout_user(
Returns 403 if the password is incorrect or the user doesn't exist.
"""
if cors_ok and token_params:

try:
db.logout(token_params["user_id"])
token = request.headers.get("Authorization", "").split(" ")[1]
db.logout(token_params["user_id"], token)
return {"status": "success"}
except psycopg2.Error as e:
logger.critical(f"Error: {e}")
raise HTTPException(status_code=500, detail="Database error")
else:
raise HTTPException(status_code=403, detail="Invalid username or password")
else:
raise HTTPException(status_code=403, detail="Invalid username or password")
raise HTTPException(status_code=403, detail="Invalid username or password")


class FeedbackRequest(BaseModel):
Expand Down Expand Up @@ -460,6 +475,7 @@ async def request_password_reset(
user_id, _, _, _ = db.retrieve_user_info(req.email)
reset_token = db.generate_token(user_id, "reset")
db.save_reset_token(user_id, reset_token)
# shall we also revoke login and refresh tokens?
tenv = Environment(loader=FileSystemLoader(template_dir))
template = tenv.get_template("password_reset.html")
rendered_template = template.render(reset_token=reset_token)
Expand All @@ -480,13 +496,11 @@ async def request_password_reset(
else:
logger.warning("No sendgrid key")
logger.info(f"Would have sent: {message}")
return {"status": "success"}
except Exception as e:
print(e.message)
else:
# Even if the email doesn't exist, we return success.
# So this can't be used to work out who is on our system.
return {"status": "success"}
# Even if the email doesn't exist, we return success.
# So this can't be used to work out who is on our system.
return {"status": "success"}

else:
raise HTTPException(status_code=403, detail="CORS note permitted.")
Expand Down
8 changes: 8 additions & 0 deletions sql/06_alter_user_tokens.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
BEGIN;
ALTER TABLE user_tokens DROP CONSTRAINT user_tokens_pkey;
ALTER TABLE user_tokens ALTER COLUMN user_id DROP DEFAULT;
ALTER TABLE user_tokens ALTER COLUMN user_id SET NOT NULL;
ALTER TABLE user_tokens RENAME TO access_tokens;
ALTER TABLE access_tokens ADD CONSTRAINT access_tokens_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE access_tokens ADD COLUMN id SERIAL PRIMARY KEY;
COMMIT;
8 changes: 8 additions & 0 deletions sql/07_create_refresh_tokens.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
CREATE TABLE refresh_tokens (
id SERIAL PRIMARY KEY,
access_token_id INTEGER NOT NULL,
user_id INTEGER NOT NULL,
token VARCHAR(255) NOT NULL,
FOREIGN KEY (user_id) REFERENCES users(id),
FOREIGN KEY (access_token_id) REFERENCES access_tokens(id) ON DELETE CASCADE
);
Loading

0 comments on commit 297ff61

Please sign in to comment.