internal.auth.security
Sensitive information manipulation.
"""Sensitive information manipulation."""

import asyncio
import json
from secrets import choice, token_urlsafe

from bcrypt import checkpw, gensalt, hashpw
from fastapi import HTTPException, Request, status
from pydantic import BaseModel, SecretStr
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncConnection

from internal.database.manager import database_manager
from internal.logger.logger import logger
from internal.queries.activity_log import AsyncQuerier as ActivityLogQuerier
from internal.queries.activity_log import CreateActivityLogParams
from internal.queries.models import UserRole
from internal.queries.token import AsyncQuerier as TokenQuerier
from internal.queries.user import AsyncQuerier as UserQuerier
from internal.queries.user import UpdateUserPasswordParams, UpdateUserPasswordRow

# Lowercase request-body keys whose values are replaced with "REDACTED"
# before a body is written to the activity log.
SENSITIVE_FIELDS = {
    "password",
    "pw_hash",
    "token",
    "claim_code",
    "email",
    "new_password",
}


class LogData(BaseModel):
    """Data for activity log.

    Built per request by ``log_request`` and persisted by ``log_to_db``.
    NOTE(review): ``user_role`` and ``query_params`` are collected but
    ``log_to_db`` does not write them.
    """

    user_id: int | None  # id of the token's user, or None when anonymous
    user_role: UserRole | None  # role of that user, or None when anonymous
    method: str  # HTTP method, e.g. "POST"
    path: str  # request URL path
    query_params: dict[str, str]  # query parameters as a plain dict
    ip_address: str | None  # client host, or None if unavailable
    body: dict[str, object] | None  # sanitized JSON body, or None


def sanitize_body(
    body: dict[str, object],
    sensitive_fields: set[str] | frozenset[str] | None = None,
) -> dict[str, object]:
    """Redact sensitive fields from a request body, including nested ones.

    Keys are matched case-insensitively against ``sensitive_fields`` (whose
    entries must therefore be lowercase) and matching values are replaced
    with the string ``"REDACTED"``. Nested dicts, and dicts inside lists, are
    sanitized recursively so values like ``{"user": {"password": ...}}`` do
    not leak into the activity log.

    Args:
        body: decoded JSON request body.
        sensitive_fields: lowercase key names to redact; defaults to
            ``SENSITIVE_FIELDS``.

    Returns:
        A new dict with sensitive values redacted; ``body`` is not modified.
    """
    fields = SENSITIVE_FIELDS if sensitive_fields is None else sensitive_fields
    return {
        key: "REDACTED" if key.lower() in fields else _redact_nested(value, fields)
        for key, value in body.items()
    }


def _redact_nested(value: object, fields: set[str] | frozenset[str]) -> object:
    """Recursively sanitize dicts and lists found inside a body value."""
    if isinstance(value, dict):
        return sanitize_body(value, fields)
    if isinstance(value, list):
        return [_redact_nested(item, fields) for item in value]
    return value


async def get_user_from_token(token: str) -> tuple[int, UserRole] | tuple[None, None]:
    """Get user_id and role from a Bearer token.

    Args:
        token: raw token value, without the ``Bearer `` prefix.

    Returns:
        Tuple of (user_id, user_role), or (None, None) if no session matches
        the token or the database lookup fails.
    """
    session = None
    try:
        # Let the connection generator run to completion (as log_to_db does)
        # instead of returning from inside the loop: a generator abandoned
        # mid-iteration only runs its post-yield cleanup when it is finalized.
        # Assumes get_connection() yields a single connection - TODO confirm.
        async for conn in database_manager.get_connection():
            session = await TokenQuerier(conn).get_session_by_token(token=token)
    except SQLAlchemyError:
        logger.exception("Database error fetching user from token")
        return None, None
    if session:
        return session.user_id, session.role
    return None, None


async def log_to_db(log_data: LogData) -> None:
    """Persist one activity-log entry built from ``log_data``.

    The action is stored as "<METHOD> <path>" and the sanitized body, if
    any, as JSON under a ``"body"`` key in ``details``. SQLAlchemy errors
    are logged and swallowed rather than propagated.
    """
    details = json.dumps({"body": log_data.body}) if log_data.body else None
    try:
        async for conn in database_manager.get_connection():
            entry = CreateActivityLogParams(
                user_id=log_data.user_id,
                action=f"{log_data.method} {log_data.path}",
                details=details,
                ip_address=log_data.ip_address or "",
            )
            await ActivityLogQuerier(conn).create_activity_log(entry)
            await conn.commit()
    except SQLAlchemyError:
        logger.exception("Failed to log to database")


async def log_request(request: Request) -> None:
    """Log request details to the activity_log table.

    For POST/PATCH/PUT requests with a JSON content type, the body is decoded
    and passed through ``sanitize_body`` before being stored. A request with
    a Bearer token is attributed to the token's user and is not logged at all
    if the token does not resolve to a user. Requests without a Bearer token
    are logged anonymously.

    Args:
        request: incoming request.
    """
    body: dict[str, object] | None = None
    # Media types are case-insensitive, so normalize before matching.
    content_type = request.headers.get("content-type", "").lower()
    if (
        request.method in {"POST", "PATCH", "PUT"}
        and "application/json" in content_type
    ):
        try:
            raw_body = await request.json()
        except (ValueError, KeyError):
            # ValueError covers json.JSONDecodeError and the UnicodeDecodeError
            # raised for non-UTF-8 bytes: no body or an invalid one is skipped.
            raw_body = None
        if isinstance(raw_body, dict):
            body = sanitize_body(raw_body)

    user_id: int | None = None
    user_role: UserRole | None = None
    auth_header = request.headers.get("authorization")
    if auth_header is not None and auth_header.startswith("Bearer "):
        user_id, user_role = await get_user_from_token(
            auth_header.removeprefix("Bearer ")
        )
        if user_id is None:
            # Bearer token that does not resolve to a user: do not log.
            return

    log_data = LogData(
        user_id=user_id,
        user_role=user_role,
        method=request.method,
        path=str(request.url.path),
        query_params=dict(request.query_params),
        ip_address=request.client.host if request.client else None,
        body=body,
    )
    await log_to_db(log_data)


def hash_password(password: str) -> str:
    """Hash a plain-text password with bcrypt.

    A fresh salt (cost factor 12) is generated on every call, so hashing the
    same password twice yields different strings.

    Args:
        password: plain text password.

    Returns:
        The bcrypt hash as a UTF-8 string; the salt is embedded in it.
    """
    raw = password.encode("utf-8")
    hashed = hashpw(raw, gensalt(rounds=12))
    return hashed.decode("utf-8")


def check_password(password: str, password_hash: str) -> bool:
    """Check whether a plain-text password matches a stored bcrypt hash.

    Args:
        password: plain text password.
        password_hash: bcrypt hash string the password is checked against.

    Returns:
        True if the password matches the hash, False otherwise.
    """
    candidate, stored = password.encode("utf-8"), password_hash.encode("utf-8")
    return checkpw(candidate, stored)


def generate_token() -> str:
    """Generate a cryptographically secure, URL-safe random token.

    The token encodes 32 random bytes (256 bits) from ``secrets``.

    Returns:
        Randomly generated token as a string.
    """
    nbytes = 32  # 256 bits of entropy
    return token_urlsafe(nbytes)


class UpdatePasswordForm(BaseModel):
    """User form to update password.

    Both values are ``SecretStr`` so they are masked when the model is
    printed; ``update_pw`` reads them with ``get_secret_value()``.
    """

    old_password: SecretStr  # current password, verified before the change
    new_password: SecretStr  # replacement password, hashed before storing


async def update_pw(
    email: str, form: UpdatePasswordForm, conn: AsyncConnection
) -> UpdateUserPasswordRow:
    """Change a user's password after verifying their current one.

    bcrypt work runs in a worker thread via ``asyncio.to_thread`` so it does
    not block the event loop. No commit is issued here; presumably the
    caller commits ``conn``.

    Args:
        email: user email.
        form: form with the old and new passwords.
        conn: database connection.

    Returns:
        Updated user record.

    Raises:
        HTTPException: 500 if the user is not found or the update returns no
            row; 403 if the old password is incorrect.
    """
    users = UserQuerier(conn)
    account = await users.get_user_login(email=email)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to find user",
        )

    old_plain = form.old_password.get_secret_value()
    if not await asyncio.to_thread(check_password, old_plain, account.pw_hash):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Old password is incorrect"
        )

    new_plain = form.new_password.get_secret_value()
    new_hash = await asyncio.to_thread(hash_password, new_plain)
    updated = await users.update_user_password(
        UpdateUserPasswordParams(user_id=account.user_id, pw_hash=new_hash)
    )
    if not updated:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update password",
        )
    return updated


def generate_claim_code(used_codes: list[str]) -> str:
    """Generate a random claim code of the form ``XX-XX``.

    Characters are drawn with ``secrets.choice`` from an alphabet that omits
    ``I``, ``O``, ``0`` and ``1`` (presumably to avoid look-alike characters).

    Args:
        used_codes: codes to avoid; any iterable of strings works.

    Returns:
        A claim code not present in ``used_codes``.

    Raises:
        RuntimeError: if every possible code is already in ``used_codes``.
    """
    alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
    # Build the lookup set once: O(1) membership per attempt.
    taken = set(used_codes)
    capacity = len(alphabet) ** 4
    if len(taken) >= capacity:
        # Only count well-formed codes so unrelated strings cannot trigger this.
        well_formed = sum(
            1
            for code in taken
            if len(code) == 5
            and code[2] == "-"
            and all(ch in alphabet for ch in code[:2] + code[3:])
        )
        if well_formed >= capacity:
            raise RuntimeError("All claim codes are in use")
    while True:
        raw_code = "".join(choice(alphabet) for _ in range(4))
        code = f"{raw_code[:2]}-{raw_code[2:]}"
        if code not in taken:
            return code
SENSITIVE_FIELDS =
{'pw_hash', 'new_password', 'claim_code', 'token', 'email', 'password'}
class
LogData(pydantic.main.BaseModel):
class LogData(BaseModel):
    """Data for activity log.

    Built per request by ``log_request`` and persisted by ``log_to_db``.
    NOTE(review): ``user_role`` and ``query_params`` are collected but
    ``log_to_db`` does not write them.
    """

    user_id: int | None  # id of the token's user, or None when anonymous
    user_role: UserRole | None  # role of that user, or None when anonymous
    method: str  # HTTP method, e.g. "POST"
    path: str  # request URL path
    query_params: dict[str, str]  # query parameters as a plain dict
    ip_address: str | None  # client host, or None if unavailable
    body: dict[str, object] | None  # sanitized JSON body, or None
Data for activity log.
def
sanitize_body(body: dict[str, object]) -> dict[str, object]:
45def sanitize_body(body: dict[str, object]) -> dict[str, object]: 46 """Redact sensitive fields from request body. 47 48 Returns: 49 Body dict with sensitive fields redacted. 50 """ 51 result: dict[str, object] = {} 52 for key, value in body.items(): 53 if key.lower() in SENSITIVE_FIELDS: 54 result[key] = "REDACTED" 55 else: 56 result[key] = value 57 return result
Redact sensitive fields from a request body.
Returns:
Body dict with sensitive fields redacted.
async def
get_user_from_token( token: str) -> tuple[int, internal.queries.models.UserRole] | tuple[None, None]:
async def get_user_from_token(token: str) -> tuple[int, UserRole] | tuple[None, None]:
    """Get user_id and role from a Bearer token.

    Args:
        token: raw token value, without the ``Bearer `` prefix.

    Returns:
        Tuple of (user_id, user_role), or (None, None) if no session matches
        the token or the database lookup fails.
    """
    session = None
    try:
        # Let the connection generator run to completion (as log_to_db does)
        # instead of returning from inside the loop: a generator abandoned
        # mid-iteration only runs its post-yield cleanup when it is finalized.
        # Assumes get_connection() yields a single connection - TODO confirm.
        async for conn in database_manager.get_connection():
            session = await TokenQuerier(conn).get_session_by_token(token=token)
    except SQLAlchemyError:
        logger.exception("Database error fetching user from token")
        return None, None
    if session:
        return session.user_id, session.role
    return None, None
Get user_id and role from Bearer token.
Returns:
Tuple of (user_id, user_role) or (None, None) if not found.
async def log_to_db(log_data: LogData) -> None:
    """Persist one activity-log entry built from ``log_data``.

    The action is stored as "<METHOD> <path>" and the sanitized body, if
    any, as JSON under a ``"body"`` key in ``details``. SQLAlchemy errors
    are logged and swallowed rather than propagated.
    """
    details = json.dumps({"body": log_data.body}) if log_data.body else None
    try:
        async for conn in database_manager.get_connection():
            entry = CreateActivityLogParams(
                user_id=log_data.user_id,
                action=f"{log_data.method} {log_data.path}",
                details=details,
                ip_address=log_data.ip_address or "",
            )
            await ActivityLogQuerier(conn).create_activity_log(entry)
            await conn.commit()
    except SQLAlchemyError:
        logger.exception("Failed to log to database")
Log activity to database.
async def
log_request(request: starlette.requests.Request) -> None:
async def log_request(request: Request) -> None:
    """Log request details to the activity_log table.

    For POST/PATCH/PUT requests with a JSON content type, the body is decoded
    and passed through ``sanitize_body`` before being stored. A request with
    a Bearer token is attributed to the token's user and is not logged at all
    if the token does not resolve to a user. Requests without a Bearer token
    are logged anonymously.

    Args:
        request: incoming request.
    """
    body: dict[str, object] | None = None
    # Media types are case-insensitive, so normalize before matching.
    content_type = request.headers.get("content-type", "").lower()
    if (
        request.method in {"POST", "PATCH", "PUT"}
        and "application/json" in content_type
    ):
        try:
            raw_body = await request.json()
        except (ValueError, KeyError):
            # ValueError covers json.JSONDecodeError and the UnicodeDecodeError
            # raised for non-UTF-8 bytes: no body or an invalid one is skipped.
            raw_body = None
        if isinstance(raw_body, dict):
            body = sanitize_body(raw_body)

    user_id: int | None = None
    user_role: UserRole | None = None
    auth_header = request.headers.get("authorization")
    if auth_header is not None and auth_header.startswith("Bearer "):
        user_id, user_role = await get_user_from_token(
            auth_header.removeprefix("Bearer ")
        )
        if user_id is None:
            # Bearer token that does not resolve to a user: do not log.
            return

    log_data = LogData(
        user_id=user_id,
        user_role=user_role,
        method=request.method,
        path=str(request.url.path),
        query_params=dict(request.query_params),
        ip_address=request.client.host if request.client else None,
        body=body,
    )
    await log_to_db(log_data)
Log request details to activity_log table.
def
hash_password(password: str) -> str:
def hash_password(password: str) -> str:
    """Hash a plain-text password with bcrypt.

    A fresh salt (cost factor 12) is generated on every call, so hashing the
    same password twice yields different strings.

    Args:
        password: plain text password.

    Returns:
        The bcrypt hash as a UTF-8 string; the salt is embedded in it.
    """
    raw = password.encode("utf-8")
    hashed = hashpw(raw, gensalt(rounds=12))
    return hashed.decode("utf-8")
Hash a password securely with bcrypt.
Arguments:
- password: plain text password
Returns:
hashed password with salt
def
check_password(password: str, password_hash: str) -> bool:
def check_password(password: str, password_hash: str) -> bool:
    """Check whether a plain-text password matches a stored bcrypt hash.

    Args:
        password: plain text password.
        password_hash: bcrypt hash string the password is checked against.

    Returns:
        True if the password matches the hash, False otherwise.
    """
    candidate, stored = password.encode("utf-8"), password_hash.encode("utf-8")
    return checkpw(candidate, stored)
Check whether a password matches the given password hash.
Arguments:
- password: plain text password
- password_hash: hash that the password is checked against
Returns:
True if the password matches the password hash, False otherwise
def
generate_token() -> str:
def generate_token() -> str:
    """Generate a cryptographically secure, URL-safe random token.

    The token encodes 32 random bytes (256 bits) from ``secrets``.

    Returns:
        Randomly generated token as a string.
    """
    nbytes = 32  # 256 bits of entropy
    return token_urlsafe(nbytes)
Generate a cryptographically secure, URL-safe random token.
Returns:
randomly generated token as a string
class
UpdatePasswordForm(pydantic.main.BaseModel):
class UpdatePasswordForm(BaseModel):
    """User form to update password.

    Both values are ``SecretStr`` so they are masked when the model is
    printed; ``update_pw`` reads them with ``get_secret_value()``.
    """

    old_password: SecretStr  # current password, verified before the change
    new_password: SecretStr  # replacement password, hashed before storing
User form to update password.
async def
update_pw( email: str, form: UpdatePasswordForm, conn: sqlalchemy.ext.asyncio.engine.AsyncConnection) -> internal.queries.user.UpdateUserPasswordRow:
async def update_pw(
    email: str, form: UpdatePasswordForm, conn: AsyncConnection
) -> UpdateUserPasswordRow:
    """Change a user's password after verifying their current one.

    bcrypt work runs in a worker thread via ``asyncio.to_thread`` so it does
    not block the event loop. No commit is issued here; presumably the
    caller commits ``conn``.

    Args:
        email: user email.
        form: form with the old and new passwords.
        conn: database connection.

    Returns:
        Updated user record.

    Raises:
        HTTPException: 500 if the user is not found or the update returns no
            row; 403 if the old password is incorrect.
    """
    users = UserQuerier(conn)
    account = await users.get_user_login(email=email)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to find user",
        )

    old_plain = form.old_password.get_secret_value()
    if not await asyncio.to_thread(check_password, old_plain, account.pw_hash):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Old password is incorrect"
        )

    new_plain = form.new_password.get_secret_value()
    new_hash = await asyncio.to_thread(hash_password, new_plain)
    updated = await users.update_user_password(
        UpdateUserPasswordParams(user_id=account.user_id, pw_hash=new_hash)
    )
    if not updated:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update password",
        )
    return updated
Update user password with old password check.
Arguments:
- email: user email
- form: form with information to update password
- conn: database connection
Returns:
updated user record
Raises:
- HTTPException: if the user is not found (500), the old password is incorrect (403), or the update fails (500)
def
generate_claim_code(used_codes: list[str]) -> str:
def generate_claim_code(used_codes: list[str]) -> str:
    """Generate a random claim code of the form ``XX-XX``.

    Characters are drawn with ``secrets.choice`` from an alphabet that omits
    ``I``, ``O``, ``0`` and ``1`` (presumably to avoid look-alike characters).

    Args:
        used_codes: codes to avoid; any iterable of strings works.

    Returns:
        A claim code not present in ``used_codes``.

    Raises:
        RuntimeError: if every possible code is already in ``used_codes``.
    """
    alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
    # Build the lookup set once: O(1) membership per attempt.
    taken = set(used_codes)
    capacity = len(alphabet) ** 4
    if len(taken) >= capacity:
        # Only count well-formed codes so unrelated strings cannot trigger this.
        well_formed = sum(
            1
            for code in taken
            if len(code) == 5
            and code[2] == "-"
            and all(ch in alphabet for ch in code[:2] + code[3:])
        )
        if well_formed >= capacity:
            raise RuntimeError("All claim codes are in use")
    while True:
        raw_code = "".join(choice(alphabet) for _ in range(4))
        code = f"{raw_code[:2]}-{raw_code[2:]}"
        if code not in taken:
            return code
Generate claim code.
Arguments:
- used_codes: list of codes to avoid
Returns:
claim code