internal.analytics.forecast
Forecasting logic for demand prediction of upcoming bundles.
"""Forecasting logic for demand prediction of upcoming bundles."""

import datetime
from collections import defaultdict
from decimal import Decimal
from statistics import mean
from typing import Protocol

import numpy as np
from internal.queries.models import DayOfWeek, ForecastInput, WeatherFlag
from lightgbm import LGBMRegressor
from pydantic import BaseModel

# --- Constants ---

# The minimum number of rows required before we trust the ML model.
# Below this threshold we fall back to the similarity-weighted average, which
# is more stable on very small datasets.
_MIN_ML_SAMPLES: int = 10

# Map each enum member to a stable integer index based on definition order.
# LightGBM and numpy can only work with numbers, not enum values, so these
# dictionaries are used inside _encode() to convert before training/predicting.
_DAY_INDEX: dict[DayOfWeek, int] = {d: i for i, d in enumerate(DayOfWeek)}
_WEATHER_INDEX: dict[WeatherFlag, int] = {w: i for i, w in enumerate(WeatherFlag)}

# Cold-start defaults, returned when a seller has zero historical rows for a
# category. Values are intentionally conservative: predict low demand and low
# confidence so the seller is not misled by a guess with no evidence behind it.
_COLD_RESERVATIONS: int = 1
_COLD_NO_SHOW_PROB: Decimal = Decimal("0.1500")
_COLD_CONFIDENCE: Decimal = Decimal("0.0500")
_COLD_RATIONALE: str = (
    "No historical data found for this seller and category. "
    "Forecast is a conservative cold-start estimate."
)

# --- Models ---


class ForecastQuery(BaseModel):
    """The conditions of an upcoming bundle to forecast demand for."""

    bundle_id: int
    seller_id: int
    # A bundle can belong to more than one category. We forecast each category
    # separately and average the results, so this is a list rather than a
    # single int. See generate_forecast() for how the averaging works.
    category_ids: list[int]
    day_of_week: DayOfWeek
    window_start: datetime.datetime
    window_end: datetime.datetime
    is_holiday: bool
    temperature: Decimal
    weather_flag: WeatherFlag
    posted_qty: int


class ForecastResult(BaseModel):
    """Forecast prediction ready to be written to "forecast_output"."""

    bundle_id: int
    seller_id: int
    window_start: datetime.datetime
    predicted_sales: int
    posted_qty: int
    predicted_no_show_prob: Decimal
    confidence: Decimal
    rationale: str | None


# --- Helpers ---


class _HasFeatures(Protocol):
    """Structural protocol shared by ForecastInput and ForecastQuery.

    Both types carry the same six condition fields. Declaring this protocol
    lets _encode() accept either type without them needing a shared base class.
    """

    day_of_week: DayOfWeek
    window_start: datetime.datetime
    window_end: datetime.datetime
    is_holiday: bool
    temperature: Decimal
    weather_flag: WeatherFlag


def _encode(obj: _HasFeatures) -> list[float]:
    """Encode any object matching _HasFeatures into a feature vector.

    Feature order: [day_of_week, window_start, window_end,
                    is_holiday, temperature, weather_flag]

    The order here must stay fixed: LightGBM trains on columns by position,
    so the query vector must be built in exactly the same order as the
    training matrix.

    Returns:
        A list of floats representing the encoded features.
    """
    return [
        float(_DAY_INDEX[obj.day_of_week]),
        # Convert datetime to decimal hour so 9:30 becomes 9.5, 14:15 becomes
        # 14.25, etc.
        obj.window_start.hour + obj.window_start.minute / 60.0,
        obj.window_end.hour + obj.window_end.minute / 60.0,
        # bool must be cast to float because LightGBM expects all
        # features to be the same numeric type.
        float(obj.is_holiday),
        float(obj.temperature),
        float(_WEATHER_INDEX[obj.weather_flag]),
    ]


def _no_show_rate(row: ForecastInput) -> float | None:
    """Return the observed no-show rate, or None when reservations == 0.

    We return None rather than 0.0 when there were no reservations because
    0.0 would imply a perfect attendance rate, which is misleading.
    """
    if row.observed_reservations == 0:
        return None
    return row.observed_no_shows / row.observed_reservations


def _confidence_from_n(n: int) -> float:
    """Map sample count to a confidence score in range [0.05, 0.90].

    Uses the formula n / (n + 20), which grows quickly at first then levels
    off: 10 rows gives 0.33, 20 rows gives 0.50, 80 rows gives 0.80.
    Clamped to [0.05, 0.90] so confidence never drops below the cold-start
    floor (_COLD_CONFIDENCE) and never reaches 1.0.

    Returns:
        A confidence score between 0.05 and 0.90.
    """
    # max(0.05, ...) enforces the documented lower bound: the raw formula
    # yields ~0.048 at n=1, which would dip below the cold-start confidence.
    return min(0.90, max(0.05, n / (n + 20.0)))


def _build_rationale(
    n: int,
    avg_reservations: float,
    avg_no_show_rate: float,
    method: str,
    query: ForecastQuery,
) -> str:
    """Build a plain-English explanation of how the forecast was produced.

    Outputs something like:
        "Based on 15 past slots (Monday, 09:00-12:00, rainy, 8.5°C)
        using LightGBM: avg 6.3 reservations, 18.0% no-show rate."

    Returns:
        A string explanation of the forecast.
    """
    # Only append the holiday note when relevant, keeping the string short
    # for the common case where it is not a public holiday.
    holiday_note = " (public holiday)" if query.is_holiday else ""

    window = (
        f"{query.window_start.strftime('%H:%M')}-{query.window_end.strftime('%H:%M')}"
    )

    # Grammatically correct singular/plural so the output reads naturally.
    slot_word = "slot" if n == 1 else "slots"

    return (
        f"Based on {n} past {slot_word} "
        f"({query.day_of_week.value}{holiday_note}, {window}, "
        f"{query.weather_flag.value}, {float(query.temperature):.1f}°C) "
        f"using {method}: "
        f"avg {avg_reservations:.1f} reservations, "
        # Multiply by 100 to display as a percentage rather than a decimal.
        f"{avg_no_show_rate * 100:.1f}% no-show rate."
    )


# --- Private forecast assembly ---


def _forecast_single_category(
    subset: list[ForecastInput], query: ForecastQuery, bundle_id: int
) -> ForecastResult:
    """Run a forecast for a single category and return the raw result.

    Expects "subset" to already be filtered to the relevant seller and
    category. This function can be reused by generate_seller_forecasts(),
    which pre-groups history once and passes the correct slice directly.

    Returns:
        A ForecastResult with predictions for the category.
    """
    n = len(subset)

    # If no history exists for this category, return the cold-start defaults.
    if n == 0:
        return ForecastResult(
            bundle_id=bundle_id,
            seller_id=query.seller_id,
            window_start=query.window_start,
            predicted_sales=_COLD_RESERVATIONS,
            posted_qty=query.posted_qty,
            predicted_no_show_prob=_COLD_NO_SHOW_PROB,
            confidence=_COLD_CONFIDENCE,
            rationale=_COLD_RATIONALE,
        )

    # Choose the prediction method based on how much data we have.
    # LightGBM needs enough rows to build meaningful decision trees;
    # the weighted average is more stable on very small samples.
    if n >= _MIN_ML_SAMPLES:
        pred_res, pred_ns = _predict_with_lgbm(subset, query)
        method = "LightGBM"
    else:
        pred_res, pred_ns = _predict_weighted_avg(subset, query)
        method = "similarity-weighted average"

    # Round reservations to a whole number; you can't have half a reservation.
    # max(0, ...) prevents a negative prediction reaching the database if the
    # model produces one.
    predicted_sales = max(0, round(pred_res))

    # Clip the no-show probability to a valid range.
    predicted_no_show_prob = float(np.clip(pred_ns, 0.0, 1.0))

    # These averages are used only in the rationale string to give the seller
    # context on what the prediction is based on.
    avg_res = mean(r.observed_reservations for r in subset)
    no_show_rates = [rate for r in subset if (rate := _no_show_rate(r)) is not None]
    # Default to 0.0 if every historical row had zero reservations and produced
    # no valid rate; mean() would raise a StatisticsError on an empty list.
    avg_no_show = mean(no_show_rates) if no_show_rates else 0.0

    return ForecastResult(
        bundle_id=bundle_id,
        seller_id=query.seller_id,
        window_start=query.window_start,
        predicted_sales=predicted_sales,
        posted_qty=query.posted_qty,
        # Convert via string to avoid floating point imprecision leaking into
        # the Decimal; round() pins it to 4 decimal places to
        # match the DECIMAL(5,4) column in forecast_output.
        predicted_no_show_prob=Decimal(str(round(predicted_no_show_prob, 4))),
        confidence=Decimal(str(round(_confidence_from_n(n), 4))),
        rationale=_build_rationale(n, avg_res, avg_no_show, method, query),
    )


def _average_category_results(
    results: list[ForecastResult], bundle_id: int
) -> ForecastResult:
    """Average a list of per-category ForecastResults into one final result.

    Used by both generate_forecast() and generate_seller_forecasts() so the
    averaging logic lives in one place. If the averaging behaviour ever
    changes, there is one place to update.

    Raises:
        ValueError: If "results" is empty (e.g. a query with no category_ids).

    Returns:
        A single ForecastResult with averaged predictions.
    """
    # Fail loudly with a clear message instead of the opaque StatisticsError /
    # IndexError that mean() and results[0] would otherwise raise.
    if not results:
        raise ValueError("Cannot average an empty list of category results.")

    avg_reservations = mean(r.predicted_sales for r in results)
    # float() is needed here because predicted_no_show_prob and confidence are
    # stored as Decimal on each result; converting to float first keeps things
    # consistent before we round and convert back to Decimal at the end.
    avg_no_show_prob = mean(float(r.predicted_no_show_prob) for r in results)
    avg_confidence = mean(float(r.confidence) for r in results)

    n_categories = len(results)

    if n_categories == 1:
        rationale = results[0].rationale
    else:
        rationale = (
            f"Multi-category forecast (averaged across {n_categories} categories)."
        )

    first = results[0]
    return ForecastResult(
        bundle_id=bundle_id,
        seller_id=first.seller_id,
        window_start=first.window_start,
        # Round to the nearest integer; fractional reservations are meaningless.
        # max(0, ...) prevents a negative average.
        predicted_sales=max(0, round(avg_reservations)),
        posted_qty=first.posted_qty,
        predicted_no_show_prob=Decimal(str(round(avg_no_show_prob, 4))),
        confidence=Decimal(str(round(avg_confidence, 4))),
        rationale=rationale,
    )


# --- Public ---


def generate_forecast(
    history: list[ForecastInput], query: ForecastQuery, bundle_id: int
) -> ForecastResult:
    """Generate a reservation and no-show forecast for an upcoming bundle.

    Runs a separate forecast for each category in query.category_ids,
    then averages the results via _average_category_results.

    Args:
        history: All forecast_input rows for this seller.
        query: The conditions of the upcoming bundle.
        bundle_id: The bundle this forecast is attached to.

    Returns:
        A ForecastResult ready to insert into forecast_output.
    """
    # Run one forecast per category. Each call receives only the rows that
    # match both the seller and that specific category. For forecasting many
    # bundles at once, use generate_seller_forecasts() instead, which
    # pre-groups the history once upfront and avoids scanning the full list
    # on every category call.
    results = [
        _forecast_single_category(
            [
                r
                for r in history
                if r.seller_id == query.seller_id and r.category_id == category_id
            ],
            query,
            bundle_id,
        )
        for category_id in query.category_ids
    ]
    return _average_category_results(results, bundle_id)


def generate_seller_forecasts(
    history: list[ForecastInput], bundles: list[tuple[int, ForecastQuery]]
) -> list[ForecastResult]:
    """Generate forecasts for all of a seller's upcoming bundles.

    Pre-groups history by category once upfront.

    Args:
        history: All forecast_input rows for this seller.
        bundles: A list of (bundle_id, ForecastQuery) pairs — one per
            upcoming bundle to forecast.

    Returns:
        A list of ForecastResult objects ready to insert into
        forecast_output, in the same order as bundles.
    """
    # Build a dictionary keyed by category_id so each lookup is instant.
    # defaultdict(list) means we never need to check whether a key exists
    # before appending; it creates an empty list automatically on first access.
    history_by_category: dict[int, list[ForecastInput]] = defaultdict(list)
    for row in history:
        history_by_category[row.category_id].append(row)

    results = []
    for bundle_id, query in bundles:
        # For each bundle, gather one forecast per category using the
        # pre-grouped dictionary. history_by_category[category_id] returns
        # an empty list if that category has no history,
        # which then uses the cold-start path inside _forecast_single_category.
        per_category = [
            _forecast_single_category(
                history_by_category[category_id], query, bundle_id
            )
            for category_id in query.category_ids
        ]
        results.append(_average_category_results(per_category, bundle_id))

    # The order of results matches the order of bundles.
    return results


# --- Private prediction functions ---


def _predict_with_lgbm(
    subset: list[ForecastInput], query: ForecastQuery
) -> tuple[float, float]:
    """Train LightGBM regressors on subset and predict for query.

    Trains two completely separate models, one for reservation count and one
    for no-show rate.

    n_estimators scales with sample count (50-200) to avoid overfitting
    small datasets while still giving full power on larger ones.
    verbose=-1 removes training logs on every call.

    Returns:
        (predicted_reservations, predicted_no_show_rate)
    """
    # Build the feature matrix: each row is one historical bundle encoded
    # as a vector of 6 numbers. Shape: (n_samples, 6).
    x_all = np.array([_encode(r) for r in subset], dtype=float)

    # Build the target vector for the reservations model.
    # Shape: (n_samples,), one count per row.
    y_res = np.array([r.observed_reservations for r in subset], dtype=float)

    # Encode the upcoming bundle's conditions in the same format as the
    # training rows. The extra [] wrapping produces shape (1, 6) rather than
    # (6,); predict() needs 2D input.
    q_vec = np.array([_encode(query)], dtype=float)

    # Scale tree count to dataset size. More trees improve accuracy on larger
    # datasets but risk overfitting when samples are scarce.
    # Never goes below 50 trees or above 200.
    n_estimators = min(200, max(50, len(subset)))

    # Train the reservations model and predict.
    lgbm_res = LGBMRegressor(n_estimators=n_estimators, random_state=42, verbose=-1)
    lgbm_res.fit(x_all, y_res)
    # np.asarray().flat[0] safely extracts the scalar.
    pred_res = float(np.asarray(lgbm_res.predict(q_vec)).flat[0])

    # Calculate no-show rates for each historical row. Rows with zero
    # reservations return None and must be excluded from training.
    ns_rates = [_no_show_rate(r) for r in subset]

    # Boolean mask: True for rows that have a valid no-show rate.
    # Used to filter both the feature matrix and the target vector together
    # so their lengths stay aligned.
    valid_mask = np.array([v is not None for v in ns_rates])

    if valid_mask.any():
        y_ns = np.array([v for v in ns_rates if v is not None], dtype=float)
        lgbm_ns = LGBMRegressor(n_estimators=n_estimators, random_state=42, verbose=-1)
        # x_all[valid_mask] uses numpy boolean indexing to select only the
        # rows that have a valid no-show rate.
        lgbm_ns.fit(x_all[valid_mask], y_ns)
        pred_ns = float(np.asarray(lgbm_ns.predict(q_vec)).flat[0])
    else:
        # Every historical row has zero reservations, so no no-show model can
        # be trained. Default to 0.0 to avoid a crash.
        pred_ns = 0.0

    return pred_res, pred_ns


def _predict_weighted_avg(
    subset: list[ForecastInput], query: ForecastQuery
) -> tuple[float, float]:
    """Predict using a similarity-weighted average.

    Each historical row earns a similarity score based on how closely its
    conditions match the upcoming bundle. Scores are then normalised into
    weights that sum to 1.0 and used to compute a weighted average of the
    observed reservation counts and no-show rates.

    Row weights:

    - day_of_week exact match → +2.0
    - window_start proximity (±2 h) → up to +1.5
    - is_holiday exact match → +1.5
    - weather_flag exact match → +1.0
    - temperature proximity (±5°C) → up to +0.5

    Returns:
        (predicted_reservations, predicted_no_show_rate)
    """
    # Pre-compute the query's start time and temperature as plain floats once
    # so we don't repeat the same conversion inside the loop on every row.
    q_start = query.window_start.hour + query.window_start.minute / 60.0
    q_temp = float(query.temperature)

    weights: list[float] = []
    res_vals: list[float] = []
    ns_vals: list[float] = []

    for row in subset:
        score = 0.0

        # Day of week is the strongest predictor; a Monday bundle's history
        # is far more relevant to a future Monday than a Saturday's history.
        if row.day_of_week == query.day_of_week:
            score += 2.0

        # Proximity score for time of day: higher weight for closeness rather
        # than requiring an exact match. Drops to 0 at 2 hours difference.
        hour_diff = abs(
            (row.window_start.hour + row.window_start.minute / 60.0) - q_start
        )
        score += max(0.0, 1.5 * (1.0 - hour_diff / 2.0))

        # Holiday status is nearly as important as day of week; demand on a
        # bank holiday behaves very differently to a normal weekday.
        if row.is_holiday == query.is_holiday:
            score += 1.5

        # Exact weather match. Weighted lower than day and holiday because
        # weather affects demand but is less dominant than day-level patterns.
        if row.weather_flag == query.weather_flag:
            score += 1.0

        # Proximity score for temperature: full points for identical temp,
        # dropping to 0 at 5°C difference. Lowest weighted feature because
        # the weather flag already captures most of the temperature signal.
        score += max(0.0, 0.5 * (1.0 - abs(float(row.temperature) - q_temp) / 5.0))

        # Floor the weight just above zero so the normalisation below can
        # never divide by 0.
        weights.append(max(score, 1e-6))
        res_vals.append(float(row.observed_reservations))

        # Substitute 0.0 for rows with no calculable no-show rate (zero
        # reservations). Unlike the LightGBM path which skips these rows,
        # the weighted average needs all three lists to stay the same length
        # so the dot product aligns correctly.
        rate = _no_show_rate(row)
        ns_vals.append(rate if rate is not None else 0.0)

    # Normalise the raw scores so they sum to 1.0, converting them from
    # arbitrary similarity points into proper proportional weights.
    w_norm = np.array(weights, dtype=float)
    w_norm /= w_norm.sum()

    # Dot product: multiply each weight by its corresponding value and sum.
    # A row with weight 0.4 and 8 reservations contributes 3.2 to the total.
    return float(np.dot(w_norm, res_vals)), float(np.dot(w_norm, ns_vals))
class
ForecastQuery(pydantic.main.BaseModel):
47class ForecastQuery(BaseModel): 48 """The conditions of an upcoming bundle to forecast demand for.""" 49 50 bundle_id: int 51 seller_id: int 52 # A bundle can belong to more than one category. We forecast each category 53 # separately and average the results, so this is a list rather than a 54 # single int. See generate_forecast() for how the averaging works. 55 category_ids: list[int] 56 day_of_week: DayOfWeek 57 window_start: datetime.datetime 58 window_end: datetime.datetime 59 is_holiday: bool 60 temperature: Decimal 61 weather_flag: WeatherFlag 62 posted_qty: int
The conditions of an upcoming bundle to forecast demand for.
class
ForecastResult(pydantic.main.BaseModel):
65class ForecastResult(BaseModel): 66 """Forecast prediction ready to be written to "forecast_output".""" 67 68 bundle_id: int 69 seller_id: int 70 window_start: datetime.datetime 71 predicted_sales: int 72 posted_qty: int 73 predicted_no_show_prob: Decimal 74 confidence: Decimal 75 rationale: str | None
Forecast prediction ready to be written to "forecast_output".
def
generate_forecast( history: list[internal.queries.models.ForecastInput], query: ForecastQuery, bundle_id: int) -> ForecastResult:
304def generate_forecast( 305 history: list[ForecastInput], query: ForecastQuery, bundle_id: int 306) -> ForecastResult: 307 """Generate a reservation and no-show forecast for an upcoming bundle. 308 309 Runs a separate forecast for each category in query.category_ids, 310 then averages the results via _average_category_results. 311 312 Args: 313 history: All forecast_input rows for this seller. 314 query: The conditions of the upcoming bundle. 315 bundle_id: The bundle this forecast is attached to. 316 317 Returns: 318 A ForecastResult ready to insert into forecast_output. 319 """ 320 # Run one forecast per category. Each call receives only the rows that 321 # match both the seller and that specific category. For forecasting many 322 # bundles at once, use generate_seller_forecasts() instead, which 323 # pre-groups the history once upfront and avoids scanning the full list 324 # on every category call. 325 results = [ 326 _forecast_single_category( 327 [ 328 r 329 for r in history 330 if r.seller_id == query.seller_id and r.category_id == category_id 331 ], 332 query, 333 bundle_id, 334 ) 335 for category_id in query.category_ids 336 ] 337 return _average_category_results(results, bundle_id)
Generate a reservation and no-show forecast for an upcoming bundle.
Runs a separate forecast for each category in query.category_ids, then averages the results via _average_category_results.
Arguments:
- history: All forecast_input rows for this seller.
- query: The conditions of the upcoming bundle.
- bundle_id: The bundle this forecast is attached to.
Returns:
A ForecastResult ready to insert into forecast_output.
def
generate_seller_forecasts( history: list[internal.queries.models.ForecastInput], bundles: list[tuple[int, ForecastQuery]]) -> list[ForecastResult]:
340def generate_seller_forecasts( 341 history: list[ForecastInput], bundles: list[tuple[int, ForecastQuery]] 342) -> list[ForecastResult]: 343 """Generate forecasts for all of a seller's upcoming bundles. 344 345 Pre-groups history by category once upfront. 346 347 Args: 348 history: All forecast_input rows for this seller. 349 bundles: A list of (bundle_id, ForecastQuery) pairs — one per 350 upcoming bundle to forecast. 351 352 Returns: 353 A list of ForecastResult objects ready to insert into 354 forecast_output, in the same order as bundles. 355 """ 356 # Build a dictionary keyed by category_id so each lookup is instant. 357 # defaultdict(list) means we never need to check whether a key exists 358 # before appending it creates an empty list automatically on first access. 359 history_by_category: dict[int, list[ForecastInput]] = defaultdict(list) 360 for row in history: 361 history_by_category[row.category_id].append(row) 362 363 results = [] 364 for bundle_id, query in bundles: 365 # For each bundle, gather one forecast per category using the 366 # pre-grouped dictionary. history_by_category[category_id] returns 367 # an empty list if that category has no history, 368 # which then uses the cold-start path inside _forecast_single_category. 369 per_category = [ 370 _forecast_single_category( 371 history_by_category[category_id], query, bundle_id 372 ) 373 for category_id in query.category_ids 374 ] 375 results.append(_average_category_results(per_category, bundle_id)) 376 377 # The order of results matches the order of bundles. 378 return results
Generate forecasts for all of a seller's upcoming bundles.
Pre-groups history by category once upfront.
Arguments:
- history: All forecast_input rows for this seller.
- bundles: A list of (bundle_id, ForecastQuery) pairs — one per upcoming bundle to forecast.
Returns:
A list of ForecastResult objects ready to insert into forecast_output, in the same order as bundles.