internal.auth.security

Sensitive information manipulation.

  1"""Sensitive information manimulation."""
  2
  3import asyncio
  4import json
  5from secrets import choice, token_urlsafe
  6
  7from bcrypt import checkpw, gensalt, hashpw
  8from fastapi import HTTPException, Request, status
  9from pydantic import BaseModel, SecretStr
 10from sqlalchemy.exc import SQLAlchemyError
 11from sqlalchemy.ext.asyncio import AsyncConnection
 12
 13from internal.database.manager import database_manager
 14from internal.logger.logger import logger
 15from internal.queries.activity_log import AsyncQuerier as ActivityLogQuerier
 16from internal.queries.activity_log import CreateActivityLogParams
 17from internal.queries.models import UserRole
 18from internal.queries.token import AsyncQuerier as TokenQuerier
 19from internal.queries.user import AsyncQuerier as UserQuerier
 20from internal.queries.user import UpdateUserPasswordParams, UpdateUserPasswordRow
 21
 22SENSITIVE_FIELDS = {
 23    "password",
 24    "pw_hash",
 25    "token",
 26    "claim_code",
 27    "email",
 28    "new_password",
 29}
 30
 31
class LogData(BaseModel):
    """Data for activity log.

    Built by ``log_request`` and passed to ``log_to_db``.
    """

    # Id of the user resolved from the Bearer token; None for requests without one.
    user_id: int | None
    # Role of that user; None when no user was resolved.
    # NOTE(review): log_to_db does not currently persist this field.
    user_role: UserRole | None
    # HTTP method, e.g. "POST".
    method: str
    # Request URL path.
    path: str
    # Query-string parameters of the request.
    # NOTE(review): log_to_db does not currently persist this field.
    query_params: dict[str, str]
    # Client host as reported by the server, if any.
    ip_address: str | None
    # Redacted JSON-object request body, or None when there is none to log.
    body: dict[str, object] | None
 42
 43
 44def sanitize_body(body: dict[str, object]) -> dict[str, object]:
 45    """Redact sensitive fields from request body.
 46
 47    Returns:
 48        Body dict with sensitive fields redacted.
 49    """
 50    result: dict[str, object] = {}
 51    for key, value in body.items():
 52        if key.lower() in SENSITIVE_FIELDS:
 53            result[key] = "REDACTED"
 54        else:
 55            result[key] = value
 56    return result
 57
 58
 59async def get_user_from_token(token: str) -> tuple[int, UserRole] | tuple[None, None]:
 60    """Get user_id and role from Bearer token.
 61
 62    Returns:
 63        Tuple of (user_id, user_role) or (None, None) if not found.
 64    """
 65    try:
 66        async for conn in database_manager.get_connection():
 67            session = await TokenQuerier(conn).get_session_by_token(token=token)
 68            if session:
 69                return session.user_id, session.role
 70    except SQLAlchemyError:
 71        logger.exception("Database error fetching user from token")
 72    return None, None
 73
 74
 75async def log_to_db(log_data: LogData) -> None:
 76    """Log activity to database."""
 77    details: str | None = None
 78    if log_data.body:
 79        details = json.dumps({"body": log_data.body})
 80
 81    try:
 82        async for conn in database_manager.get_connection():
 83            await ActivityLogQuerier(conn).create_activity_log(
 84                CreateActivityLogParams(
 85                    user_id=log_data.user_id,
 86                    action=f"{log_data.method} {log_data.path}",
 87                    details=details,
 88                    ip_address=log_data.ip_address or "",
 89                )
 90            )
 91            await conn.commit()
 92    except SQLAlchemyError:
 93        logger.exception("Failed to log to database")
 94        return
 95
 96
 97async def log_request(request: Request) -> None:
 98    """Log request details to activity_log table."""
 99    body = None
100    content_type = request.headers.get("content-type", "")
101    if (
102        request.method in {"POST", "PATCH", "PUT"}
103        and "application/json" in content_type
104    ):
105        try:
106            raw_body = await request.json()
107            if isinstance(raw_body, dict):
108                body = sanitize_body(raw_body)
109        except json.JSONDecodeError, KeyError:
110            pass  # no body or invalid json
111
112    user_id: int | None = None
113    user_role: UserRole | None = None
114    auth_header = request.headers.get("authorization")
115    is_bearer = auth_header is not None and auth_header.startswith("Bearer ")
116    if is_bearer:
117        auth_values = auth_header
118        if auth_values is None:
119            logger.exception("Failed to read the auth header.")
120            return
121        token = auth_values[7:]
122        result = await get_user_from_token(token)
123        user_id, user_role = result
124
125    should_log = user_id or not is_bearer
126    if should_log:
127        log_data = LogData(
128            user_id=user_id,
129            user_role=user_role,
130            method=request.method,
131            path=str(request.url.path),
132            query_params=dict(request.query_params),
133            ip_address=request.client.host if request.client else None,
134            body=body,
135        )
136        await log_to_db(log_data)
137
138
def hash_password(password: str) -> str:
    """Hash a plain-text password with bcrypt.

    A new salt (cost factor 12) is generated on every call, so the same
    password hashes to a different string each time; verify with
    ``check_password`` instead of comparing hashes.

    Args:
      password: plain text password

    Returns:
      bcrypt hash (salt embedded) as a UTF-8 string
    """
    # NOTE(review): bcrypt only uses the first 72 bytes of input; depending on
    # the installed bcrypt version longer passwords are truncated or rejected.
    return hashpw(password.encode("utf-8"), gensalt(rounds=12)).decode("utf-8")
151
152
def check_password(password: str, password_hash: str) -> bool:
    """Verify a plain-text password against a stored bcrypt hash.

    Both arguments are UTF-8 encoded before being handed to ``bcrypt.checkpw``.

    Args:
      password: plain text password
      password_hash: bcrypt hash the password is checked against

    Returns:
      True if the password matches the hash, False otherwise
    """
    return checkpw(password.encode("utf-8"), password_hash.encode("utf-8"))
167
168
def generate_token() -> str:
    """Generate a cryptographically secure, URL-safe random token.

    Returns:
      256 bits of randomness encoded as unpadded URL-safe base64 text
    """
    entropy_bytes = 32  # 32 bytes == 256 bits
    return token_urlsafe(entropy_bytes)
177
178
class UpdatePasswordForm(BaseModel):
    """User form to update password.

    Both values are ``SecretStr``; read the plain text with
    ``get_secret_value()``.
    """

    # Current password; update_pw checks it against the stored hash.
    old_password: SecretStr
    # Replacement password; update_pw hashes it before storing.
    new_password: SecretStr
184
185
async def update_pw(
    email: str, form: UpdatePasswordForm, conn: AsyncConnection
) -> UpdateUserPasswordRow:
    """Change a user's password after verifying the current one.

    The bcrypt check and hash run via ``asyncio.to_thread`` so the event loop
    is not blocked. The update is issued on ``conn``; this function does not
    commit.

    Args:
      email: email of the user whose password changes
      form: the current and the new password
      conn: database connection used for both the lookup and the update

    Returns:
      the row returned by the password-update query

    Raises:
      HTTPException: 500 if the user is not found or the update returns no
        row; 403 if the old password does not match.
    """
    users = UserQuerier(conn)

    account = await users.get_user_login(email=email)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to find user",
        )

    current_plain = form.old_password.get_secret_value()
    matches = await asyncio.to_thread(check_password, current_plain, account.pw_hash)
    if not matches:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Old password is incorrect",
        )

    replacement_plain = form.new_password.get_secret_value()
    replacement_hash = await asyncio.to_thread(hash_password, replacement_plain)

    updated_row = await users.update_user_password(
        UpdateUserPasswordParams(user_id=account.user_id, pw_hash=replacement_hash)
    )
    if updated_row:
        return updated_row
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail="Failed to update password",
    )
228
229
def generate_claim_code(used_codes: list[str]) -> str:
    """Generate a random claim code that is not already in use.

    Codes have the form ``XX-XX``. They are drawn with ``secrets.choice`` from
    an alphabet that omits the look-alike characters I, O, 0 and 1.

    Args:
        used_codes: codes to avoid (any iterable of strings works).

    Returns:
        A claim code not contained in ``used_codes``.

    Raises:
        ValueError: if every possible code is already in ``used_codes``
            (instead of looping forever).
    """
    alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
    capacity = len(alphabet) ** 4
    # Set membership is O(1) per attempt, unlike scanning a list each time.
    taken = set(used_codes)

    if len(taken) >= capacity:
        # Only well-formed codes shrink the pool; other strings can never
        # collide with a generated code.
        valid_chars = set(alphabet)
        taken_valid = sum(
            1
            for code in taken
            if len(code) == 5
            and code[2] == "-"
            and set(code[:2] + code[3:]) <= valid_chars
        )
        if taken_valid >= capacity:
            raise ValueError("No unused claim codes remain")

    while True:
        raw_code = "".join(choice(alphabet) for _ in range(4))
        code = f"{raw_code[:2]}-{raw_code[2:]}"
        if code not in taken:
            return code
# Request-body keys whose values are replaced with "REDACTED" before a request
# body is written to the activity log. ``sanitize_body`` lower-cases keys before
# the membership test, so every entry here must be lowercase.
SENSITIVE_FIELDS = {
    "password",
    # Field of UpdatePasswordForm; without it a password change would store the
    # user's current password in plain text in activity_log.
    "old_password",
    "new_password",
    "pw_hash",
    "token",
    "claim_code",
    "email",
}
class LogData(pydantic.main.BaseModel):
class LogData(BaseModel):
    """Data for activity log.

    Built by ``log_request`` and passed to ``log_to_db``.
    """

    # Id of the user resolved from the Bearer token; None for requests without one.
    user_id: int | None
    # Role of that user; None when no user was resolved.
    # NOTE(review): log_to_db does not currently persist this field.
    user_role: UserRole | None
    # HTTP method, e.g. "POST".
    method: str
    # Request URL path.
    path: str
    # Query-string parameters of the request.
    # NOTE(review): log_to_db does not currently persist this field.
    query_params: dict[str, str]
    # Client host as reported by the server, if any.
    ip_address: str | None
    # Redacted JSON-object request body, or None when there is none to log.
    body: dict[str, object] | None

Data for activity log.

user_id: int | None = PydanticUndefined
user_role: internal.queries.models.UserRole | None = PydanticUndefined
method: str = PydanticUndefined
path: str = PydanticUndefined
query_params: dict[str, str] = PydanticUndefined
ip_address: str | None = PydanticUndefined
body: dict[str, object] | None = PydanticUndefined
def sanitize_body(body: dict[str, object]) -> dict[str, object]:
45def sanitize_body(body: dict[str, object]) -> dict[str, object]:
46    """Redact sensitive fields from request body.
47
48    Returns:
49        Body dict with sensitive fields redacted.
50    """
51    result: dict[str, object] = {}
52    for key, value in body.items():
53        if key.lower() in SENSITIVE_FIELDS:
54            result[key] = "REDACTED"
55        else:
56            result[key] = value
57    return result

Redact sensitive fields from request body.

Returns:

Body dict with sensitive fields redacted.

async def get_user_from_token( token: str) -> tuple[int, internal.queries.models.UserRole] | tuple[None, None]:
async def get_user_from_token(token: str) -> tuple[int, UserRole] | tuple[None, None]:
    """Get user_id and role from Bearer token.

    Args:
        token: raw token taken from the ``Authorization: Bearer`` header.

    Returns:
        Tuple of (user_id, user_role), or (None, None) if no session matches the
        token or the lookup fails with a database error (which is logged).
    """
    session = None
    try:
        # Let the connection iterator run to completion (as log_to_db does)
        # rather than returning from inside the loop. Presumably
        # get_connection is an async generator; an early return would defer
        # its cleanup to async-generator finalization. NOTE(review): assumes
        # it yields a single connection.
        async for conn in database_manager.get_connection():
            session = await TokenQuerier(conn).get_session_by_token(token=token)
    except SQLAlchemyError:
        logger.exception("Database error fetching user from token")
    if session:
        return session.user_id, session.role
    return None, None

Get user_id and role from Bearer token.

Returns:

Tuple of (user_id, user_role) or (None, None) if not found.

async def log_to_db(log_data: LogData) -> None:
async def log_to_db(log_data: LogData) -> None:
    """Write one row to the activity log for a request.

    The action is recorded as ``"<METHOD> <path>"``. A non-empty body is stored
    as-is (no redaction happens here), JSON-encoded under a ``"body"`` key in
    the details column. SQLAlchemy errors are logged and swallowed.

    Args:
        log_data: request details to record.
    """
    details = json.dumps({"body": log_data.body}) if log_data.body else None
    action = f"{log_data.method} {log_data.path}"

    try:
        async for connection in database_manager.get_connection():
            activity_log = ActivityLogQuerier(connection)
            await activity_log.create_activity_log(
                CreateActivityLogParams(
                    user_id=log_data.user_id,
                    action=action,
                    details=details,
                    ip_address=log_data.ip_address or "",
                )
            )
            await connection.commit()
    except SQLAlchemyError:
        logger.exception("Failed to log to database")

Log activity to database.

async def log_request(request: starlette.requests.Request) -> None:
 98async def log_request(request: Request) -> None:
 99    """Log request details to activity_log table."""
100    body = None
101    content_type = request.headers.get("content-type", "")
102    if (
103        request.method in {"POST", "PATCH", "PUT"}
104        and "application/json" in content_type
105    ):
106        try:
107            raw_body = await request.json()
108            if isinstance(raw_body, dict):
109                body = sanitize_body(raw_body)
110        except json.JSONDecodeError, KeyError:
111            pass  # no body or invalid json
112
113    user_id: int | None = None
114    user_role: UserRole | None = None
115    auth_header = request.headers.get("authorization")
116    is_bearer = auth_header is not None and auth_header.startswith("Bearer ")
117    if is_bearer:
118        auth_values = auth_header
119        if auth_values is None:
120            logger.exception("Failed to read the auth header.")
121            return
122        token = auth_values[7:]
123        result = await get_user_from_token(token)
124        user_id, user_role = result
125
126    should_log = user_id or not is_bearer
127    if should_log:
128        log_data = LogData(
129            user_id=user_id,
130            user_role=user_role,
131            method=request.method,
132            path=str(request.url.path),
133            query_params=dict(request.query_params),
134            ip_address=request.client.host if request.client else None,
135            body=body,
136        )
137        await log_to_db(log_data)

Log request details to activity_log table.

def hash_password(password: str) -> str:
def hash_password(password: str) -> str:
    """Hash a plain-text password with bcrypt.

    A new salt (cost factor 12) is generated on every call, so the same
    password hashes to a different string each time; verify with
    ``check_password`` instead of comparing hashes.

    Args:
      password: plain text password

    Returns:
      bcrypt hash (salt embedded) as a UTF-8 string
    """
    # NOTE(review): bcrypt only uses the first 72 bytes of input; depending on
    # the installed bcrypt version longer passwords are truncated or rejected.
    return hashpw(password.encode("utf-8"), gensalt(rounds=12)).decode("utf-8")

Secure password hashing.

Arguments:
  • password: plain text password
Returns:

hashed password with salt

def check_password(password: str, password_hash: str) -> bool:
def check_password(password: str, password_hash: str) -> bool:
    """Verify a plain-text password against a stored bcrypt hash.

    Both arguments are UTF-8 encoded before being handed to ``bcrypt.checkpw``.

    Args:
      password: plain text password
      password_hash: bcrypt hash the password is checked against

    Returns:
      True if the password matches the hash, False otherwise
    """
    return checkpw(password.encode("utf-8"), password_hash.encode("utf-8"))

Checks whether the password corresponds to the given password hash.

Arguments:
  • password: plain text password
  • password_hash: hash that the password is checked against
Returns:

whether the password corresponds to the password hash

def generate_token() -> str:
def generate_token() -> str:
    """Generate a cryptographically secure, URL-safe random token.

    Returns:
      256 bits of randomness encoded as unpadded URL-safe base64 text
    """
    entropy_bytes = 32  # 32 bytes == 256 bits
    return token_urlsafe(entropy_bytes)

Generates secure token for use anywhere on web.

Returns:

randomly generated token as a string

class UpdatePasswordForm(pydantic.main.BaseModel):
class UpdatePasswordForm(BaseModel):
    """User form to update password.

    Both values are ``SecretStr``; read the plain text with
    ``get_secret_value()``.
    """

    # Current password; update_pw checks it against the stored hash.
    old_password: SecretStr
    # Replacement password; update_pw hashes it before storing.
    new_password: SecretStr

User form to update password.

old_password: pydantic.types.SecretStr = PydanticUndefined
new_password: pydantic.types.SecretStr = PydanticUndefined
async def update_pw( email: str, form: UpdatePasswordForm, conn: sqlalchemy.ext.asyncio.engine.AsyncConnection) -> internal.queries.user.UpdateUserPasswordRow:
async def update_pw(
    email: str, form: UpdatePasswordForm, conn: AsyncConnection
) -> UpdateUserPasswordRow:
    """Change a user's password after verifying the current one.

    The bcrypt check and hash run via ``asyncio.to_thread`` so the event loop
    is not blocked. The update is issued on ``conn``; this function does not
    commit.

    Args:
      email: email of the user whose password changes
      form: the current and the new password
      conn: database connection used for both the lookup and the update

    Returns:
      the row returned by the password-update query

    Raises:
      HTTPException: 500 if the user is not found or the update returns no
        row; 403 if the old password does not match.
    """
    users = UserQuerier(conn)

    account = await users.get_user_login(email=email)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to find user",
        )

    current_plain = form.old_password.get_secret_value()
    matches = await asyncio.to_thread(check_password, current_plain, account.pw_hash)
    if not matches:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Old password is incorrect",
        )

    replacement_plain = form.new_password.get_secret_value()
    replacement_hash = await asyncio.to_thread(hash_password, replacement_plain)

    updated_row = await users.update_user_password(
        UpdateUserPasswordParams(user_id=account.user_id, pw_hash=replacement_hash)
    )
    if updated_row:
        return updated_row
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail="Failed to update password",
    )

Update user password with old password check.

Arguments:
  • email: user email
  • form: form with information to update password
  • conn: database connection
Returns:

updated user record

Raises:
  • HTTPException: if failed to perform update
def generate_claim_code(used_codes: list[str]) -> str:
def generate_claim_code(used_codes: list[str]) -> str:
    """Generate a random claim code that is not already in use.

    Codes have the form ``XX-XX``. They are drawn with ``secrets.choice`` from
    an alphabet that omits the look-alike characters I, O, 0 and 1.

    Args:
        used_codes: codes to avoid (any iterable of strings works).

    Returns:
        A claim code not contained in ``used_codes``.

    Raises:
        ValueError: if every possible code is already in ``used_codes``
            (instead of looping forever).
    """
    alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
    capacity = len(alphabet) ** 4
    # Set membership is O(1) per attempt, unlike scanning a list each time.
    taken = set(used_codes)

    if len(taken) >= capacity:
        # Only well-formed codes shrink the pool; other strings can never
        # collide with a generated code.
        valid_chars = set(alphabet)
        taken_valid = sum(
            1
            for code in taken
            if len(code) == 5
            and code[2] == "-"
            and set(code[:2] + code[3:]) <= valid_chars
        )
        if taken_valid >= capacity:
            raise ValueError("No unused claim codes remain")

    while True:
        raw_code = "".join(choice(alphabet) for _ in range(4))
        code = f"{raw_code[:2]}-{raw_code[2:]}"
        if code not in taken:
            return code

Generate claim code.

Arguments:
  • used_codes: list of codes to avoid
Returns:

claim code