Skip to content

Commit

Permalink
Merge pull request #37 from waleedkadous/feedback_no_token
Browse files Browse the repository at this point in the history
Make shared threads not require any auth
  • Loading branch information
abdullah-alnahas authored Apr 1, 2024
2 parents e9c1462 + 6fe1f1b commit 85f7a56
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 30 deletions.
32 changes: 16 additions & 16 deletions ansari_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def register(self, email, first_name, last_name, password_hash):
self.conn.commit()
return {"status": "success"}
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
Expand All @@ -154,7 +154,7 @@ def account_exists(self, email):
result = cur.fetchone()
return result is not None
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return False
finally:
if cur:
Expand All @@ -169,7 +169,7 @@ def save_token(self, user_id, token):
self.conn.commit()
return {"status": "success", "token": token}
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
Expand All @@ -184,7 +184,7 @@ def save_reset_token(self, user_id, token):
self.conn.commit()
return {"status": "success", "token": token}
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
Expand All @@ -202,7 +202,7 @@ def retrieve_user_info(self, email):
last_name = result[3]
return user_id, existing_hash, first_name, last_name
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return None, None, None, None
finally:
if cur:
Expand All @@ -218,7 +218,7 @@ def add_feedback(self, user_id, thread_id, message_id, feedback_class, comment):
)
return {"status": "success"}
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
Expand All @@ -234,7 +234,7 @@ def create_thread(self, user_id):
return {"status": "success", "thread_id": inserted_id}

except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}

finally:
Expand All @@ -254,7 +254,7 @@ def get_all_threads(self, user_id):
for x in result
]
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return []
finally:
if cur:
Expand All @@ -277,7 +277,7 @@ def set_thread_name(self, thread_id, user_id, thread_name):
self.conn.commit()
return {"status": "success"}
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
Expand All @@ -296,7 +296,7 @@ def append_message(self, user_id, thread_id, role, content, function_name=None):
self.conn.commit()
return {"status": "success"}
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
Expand Down Expand Up @@ -360,7 +360,7 @@ def get_thread_llm(self, thread_id, user_id):
}
return retval
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {}
finally:
if cur:
Expand All @@ -384,7 +384,7 @@ def snapshot_thread(self, thread_id, user_id):
logger.info(f"Result is {result}")
return result
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
Expand All @@ -399,7 +399,7 @@ def get_snapshot(self, share_uuid):
result = cur.fetchone()[0]
return result
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {}
finally:
if cur:
Expand All @@ -420,7 +420,7 @@ def delete_thread(self, thread_id, user_id):
self.conn.commit()
return {"status": "success"}
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
Expand All @@ -434,7 +434,7 @@ def logout(self, user_id):
self.conn.commit()
return {"status": "success"}
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
Expand Down Expand Up @@ -471,7 +471,7 @@ def update_password(self, user_id, new_password_hash):
cur.close()
return {"status": "success"}
except Exception as e:
logger.warning("Error is ", e)
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}
finally:
if cur:
Expand Down
19 changes: 6 additions & 13 deletions main_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,26 +324,19 @@ def share_thread(
def get_snapshot(
share_uuid_str: str,
cors_ok: bool = Depends(validate_cors),
token_params: dict = Depends(db.validate_token),
):
"""
Take a snapshot of a thread at this time and make it shareable.
"""
logger.info(f"Incoming share_uuid is {share_uuid_str}")
share_uuid = uuid.UUID(share_uuid_str)
if cors_ok and token_params:
logger.info(f"Token_params is {token_params}")
# TODO(mwk): check that the user_id in the token matches the
# user_id associated with the thread_id.
try:
content = db.get_snapshot(share_uuid)
return {"status": "success", "content": content}
except psycopg2.Error as e:
logger.critical(f"Error: {e}")
raise HTTPException(status_code=500, detail="Database error")
else:
raise HTTPException(status_code=403, detail="CORS not permitted")
try:
content = db.get_snapshot(share_uuid)
return {"status": "success", "content": content}
except psycopg2.Error as e:
logger.critical(f"Error: {e}")
raise HTTPException(status_code=500, detail="Database error")


@app.get("/api/v2/threads/{thread_id}")
Expand Down
3 changes: 2 additions & 1 deletion tests/test_main_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ async def test_share_thread(login_user, create_thread):
response = client.get(
f"api/v2/share/{share_uuid}",
headers={
"Authorization": f"Bearer {login_user}",
# NOTE: We do not need to pass the Authorization header here
# Accessing a shared thread does not require authentication
"x-mobile-ansari": "ANSARI",
},
)
Expand Down

0 comments on commit 85f7a56

Please sign in to comment.