database.init_db

Initialisation and deinitialisation of the database.

  1"""Initialisation and deinitialisation of the database."""
  2
  3from logging import Logger, getLogger
  4from logging.config import dictConfig
  5from pathlib import Path
  6from re import IGNORECASE, search, split
  7from string.templatelib import Template
  8
  9from internal.settings.env import database_settings
 10from psycopg import Connection, Error, connect
 11from uvicorn.config import LOGGING_CONFIG
 12
 13from database.db_constants import ALLERGENS, ANALYTICS_GRAPHS_TYPES, BADGES, CATEGORIES
 14
# Apply uvicorn's LOGGING_CONFIG as soon as this module is imported.
dictConfig(LOGGING_CONFIG)
# Relative path to the schema file; it is resolved against the current
# working directory when the file is read.
SCHEMA_PATH = Path("database/migrations/schema.sql")
 17
 18
 19def load_queries() -> list[str]:
 20    """Loads table creation queries from schema.
 21
 22    Returns:
 23      list of separate queries
 24    """
 25    content = SCHEMA_PATH.read_text(encoding="utf-8")
 26    raw_queries = split(r"\n\s*\n", content)
 27    return [q.strip() for q in raw_queries if q.strip()]
 28
 29
 30# Connect to the database
 31def get_db_connection(logger: Logger) -> Connection | None:
 32    """Opens connection to the database.
 33
 34    Args:
 35      logger: configured logger
 36
 37    Returns:
 38      Opened connection
 39    """
 40    try:
 41        connection = connect(
 42            host=database_settings.host,
 43            port=database_settings.port,
 44            dbname=database_settings.database,
 45            user=database_settings.username,
 46            password=database_settings.password,
 47        )
 48        if connection.closed == 0:
 49            logger.info("Connection to database successful.")
 50            return connection
 51        logger.error("Connection is closed")
 52        return None
 53    except Error as e:
 54        logger.error(f"Error while connecting to Postgres: {e}")
 55        return None
 56
 57
 58# Get entity name and type
 59def get_type_and_name(sql: str) -> tuple[str, str] | None:
 60    """Extracts name and type of the entity from the query.
 61
 62    Args:
 63        sql: sql string
 64
 65    Returns:
 66      tuple with table type and name
 67    """
 68    table_match = search(
 69        r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?`?(\w+)`?", sql, IGNORECASE
 70    )
 71    if table_match:
 72        return ("TABLE", table_match.group(1))
 73
 74    enum_match = search(r"CREATE\s+TYPE\s+`?(\w+)`?\s+AS\s+ENUM", sql, IGNORECASE)
 75    if enum_match:
 76        return ("ENUM", enum_match.group(1))
 77    return None
 78
 79
 80def create_all_tables(
 81    logger: Logger, table_queries: list[str], conn: Connection
 82) -> None:
 83    """Execute sql for each table.
 84
 85    Args:
 86      logger: configured logger
 87      table_queries: list of the queries for table craetion
 88      conn: connection to the database
 89    """
 90    for sql in table_queries:
 91        type_name = get_type_and_name(sql)
 92        if not type_name:
 93            logger.error("unrecognised entity type")
 94            continue
 95        try:
 96            with conn.cursor() as cursor:
 97                cursor.execute(Template(sql))
 98                logger.info(f"{type_name[0].lower()} {type_name[1]} created.")
 99        except Error as err:
100            logger.error(f"Failed to create {type_name[0]} {type_name[1]}: {err}")
101        conn.commit()
102
103
def show_all_tables(logger: Logger, conn: Connection) -> None:
    """Log every table whose schema is not pg_catalog or information_schema.

    Args:
      logger: logger the fetched rows are written to
      conn: connection to the database
    """
    # Whitespace inside the literal is kept exactly as the statement was
    # originally written.
    listing_query = """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema NOT IN ('pg_catalog', 'information_schema');
        """
    with conn.cursor() as cursor:
        cursor.execute(listing_query)
        rows = cursor.fetchall()
        logger.info("Current tables in the database:")
        for row in rows:
            # Each row is logged as fetched, without unpacking it.
            logger.info(row)
121
122
def drop_all_tables(logger: Logger, table_queries: list[str], conn: Connection) -> None:
    """Drop the tables named in the schema statements, last statement first.

    Statements whose entity cannot be recognised are logged and skipped.
    Enum types are skipped with a warning and are not dropped. All drops
    are committed together at the end; an error raised by a DROP is not
    caught here and propagates to the caller.

    Args:
      logger: configured logger
      table_queries: list of the queries for table creation; it is read in
        reverse order but left unmodified
      conn: connection to the database
    """
    # reversed() walks the list backwards without reordering the caller's list.
    for sql in reversed(table_queries):
        type_name = get_type_and_name(sql)
        if type_name is None:
            logger.error("unrecognised entity type")
            continue
        entity_type, entity_name = type_name
        if entity_type == "ENUM":
            logger.warning(f"skipping {entity_name} as enum")
            continue
        statement = f"DROP {entity_type} IF EXISTS {entity_name};"
        with conn.cursor() as cursor:
            cursor.execute(Template(statement))
            logger.info(f"table {entity_name} removed.")
    conn.commit()
144
145
def seed_static_data(logger: Logger, conn: Connection) -> None:
    """Populate the lookup tables from the constants in database.db_constants.

    Rows for categories, allergens, badges and analytics graph types are
    inserted in that order; rows that conflict with existing ones are
    skipped. Everything is committed together. A psycopg error is logged
    and the whole batch is rolled back.

    Args:
      logger: configured logger
      conn: connection to the database
    """
    category_sql = (
        "INSERT INTO category "
        "(category_id, category_name, category_coefficient)"
        "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING"
    )
    allergen_sql = (
        "INSERT INTO allergens (allergen_id, allergen_name)"
        "VALUES (%s, %s) ON CONFLICT DO NOTHING"
    )
    badge_sql = (
        "INSERT INTO badges (badge_id, name, description)"
        "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING"
    )
    graph_type_sql = (
        "INSERT INTO analytics_graphs_types "
        "(graph_type_id, chart_type, graph_summary, x_axis_label, y_axis_label)"
        "VALUES (%s, %s, %s, %s, %s) ON CONFLICT DO NOTHING"
    )
    try:
        # Categories
        logger.info("inserting categories")
        for entry in CATEGORIES:
            row = (entry["cat_id"], entry["name"], entry["coefficient"])
            conn.execute(category_sql, row)

        # Allergens
        logger.info("inserting allergens")
        for allergen_id, allergen_name in ALLERGENS.items():
            conn.execute(allergen_sql, (allergen_id, allergen_name))

        # Badges
        logger.info("inserting badges")
        for entry in BADGES:
            row = (entry["badge_id"], entry["name"], entry["description"])
            conn.execute(badge_sql, row)

        # Analytics graph types
        logger.info("inserting analytics graphs types")
        for entry in ANALYTICS_GRAPHS_TYPES:
            row = (
                entry["graph_type_id"],
                entry["chart_type"],
                entry["graph_summary"],
                entry["x_axis_label"],
                entry["y_axis_label"],
            )
            conn.execute(graph_type_sql, row)

        conn.commit()
        logger.info("finished inserting data")
    except Error as e:
        logger.error(f"Error seeding data: {e}")
        conn.rollback()
194
195
196def main() -> None:
197    """Entrypoint for database management script."""
198    dictConfig(LOGGING_CONFIG)
199    logger = getLogger("uvicorn.info")
200    table_queries = load_queries()
201    conn = get_db_connection(logger)
202    if not conn:
203        return
204    option = 0
205    while option not in {1, 2, 3, 4}:
206        message = (
207            "(1) Init all table, (2) Show all tables, (3) Drop all tables, (4) exit: "
208        )
209        option = int(input(message))
210    if input(f"confirm option {option} (y/N): ") not in {"y", "Y", "yes"}:
211        option = 0
212    match option:
213        case 1:
214            # Create all tables
215            create_all_tables(logger, table_queries, conn)
216            seed_static_data(logger, conn)
217        case 2:
218            # Show all tables
219            show_all_tables(logger, conn)
220        case 3:
221            # Drop all tables
222            drop_all_tables(logger, table_queries, conn)
223        case 4:
224            logger.info("shutting down: database connection closed")
225    conn.close()
226    logger.info("database connection closed.")
227
228
# Start the interactive management prompt only when run as a script.
if __name__ == "__main__":
    main()
SCHEMA_PATH = PosixPath('database/migrations/schema.sql')
def load_queries() -> list[str]:
def load_queries() -> list[str]:
    """Read the schema file and break it into individual statements.

    The file at ``SCHEMA_PATH`` is decoded as UTF-8 and cut wherever one or
    more blank (whitespace-only) lines separate two pieces of text.

    Returns:
      Non-empty chunks of the schema in file order, each stripped of
      surrounding whitespace.
    """
    schema_text = SCHEMA_PATH.read_text(encoding="utf-8")
    # A blank line acts as the delimiter between statements.
    return [
        statement
        for chunk in split(r"\n\s*\n", schema_text)
        if (statement := chunk.strip())
    ]

Loads table creation queries from schema.

Returns:

list of separate queries

def get_db_connection(logger: logging.Logger) -> psycopg.Connection | None:
def get_db_connection(logger: Logger) -> Connection | None:
    """Try to open a database connection from the project database settings.

    Args:
      logger: logger that receives the outcome of the attempt

    Returns:
      The open connection, or None when connecting raised a psycopg error or
      the new connection reports itself as closed.
    """
    try:
        connection = connect(
            host=database_settings.host,
            port=database_settings.port,
            dbname=database_settings.database,
            user=database_settings.username,
            password=database_settings.password,
        )
    except Error as e:
        logger.error(f"Error while connecting to Postgres: {e}")
        return None

    # Reject a connection object that came back already closed.
    if connection.closed != 0:
        logger.error("Connection is closed")
        return None

    logger.info("Connection to database successful.")
    return connection

Opens connection to the database.

Arguments:
  • logger: configured logger
Returns:

Opened connection

def get_type_and_name(sql: str) -> tuple[str, str] | None:
60def get_type_and_name(sql: str) -> tuple[str, str] | None:
61    """Extracts name and type of the entity from the query.
62
63    Args:
64        sql: sql string
65
66    Returns:
67      tuple with table type and name
68    """
69    table_match = search(
70        r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?`?(\w+)`?", sql, IGNORECASE
71    )
72    if table_match:
73        return ("TABLE", table_match.group(1))
74
75    enum_match = search(r"CREATE\s+TYPE\s+`?(\w+)`?\s+AS\s+ENUM", sql, IGNORECASE)
76    if enum_match:
77        return ("ENUM", enum_match.group(1))
78    return None

Extracts name and type of the entity from the query.

Arguments:
  • sql: sql string
Returns:

tuple with table type and name

def create_all_tables( logger: logging.Logger, table_queries: list[str], conn: psycopg.Connection) -> None:
 81def create_all_tables(
 82    logger: Logger, table_queries: list[str], conn: Connection
 83) -> None:
 84    """Execute sql for each table.
 85
 86    Args:
 87      logger: configured logger
 88      table_queries: list of the queries for table craetion
 89      conn: connection to the database
 90    """
 91    for sql in table_queries:
 92        type_name = get_type_and_name(sql)
 93        if not type_name:
 94            logger.error("unrecognised entity type")
 95            continue
 96        try:
 97            with conn.cursor() as cursor:
 98                cursor.execute(Template(sql))
 99                logger.info(f"{type_name[0].lower()} {type_name[1]} created.")
100        except Error as err:
101            logger.error(f"Failed to create {type_name[0]} {type_name[1]}: {err}")
102        conn.commit()

Execute sql for each table.

Arguments:
  • logger: configured logger
  • table_queries: list of the queries for table creation
  • conn: connection to the database
def show_all_tables(logger: logging.Logger, conn: psycopg.Connection) -> None:
def show_all_tables(logger: Logger, conn: Connection) -> None:
    """Log every table whose schema is not pg_catalog or information_schema.

    Args:
      logger: logger the fetched rows are written to
      conn: connection to the database
    """
    # Whitespace inside the literal is kept exactly as the statement was
    # originally written.
    listing_query = """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema NOT IN ('pg_catalog', 'information_schema');
        """
    with conn.cursor() as cursor:
        cursor.execute(listing_query)
        rows = cursor.fetchall()
        logger.info("Current tables in the database:")
        for row in rows:
            # Each row is logged as fetched, without unpacking it.
            logger.info(row)

Show all tables located in the database.

Arguments:
  • logger: configured logger
  • conn: connection to the database
def drop_all_tables( logger: logging.Logger, table_queries: list[str], conn: psycopg.Connection) -> None:
def drop_all_tables(logger: Logger, table_queries: list[str], conn: Connection) -> None:
    """Drop the tables named in the schema statements, last statement first.

    Statements whose entity cannot be recognised are logged and skipped.
    Enum types are skipped with a warning and are not dropped. All drops
    are committed together at the end; an error raised by a DROP is not
    caught here and propagates to the caller.

    Args:
      logger: configured logger
      table_queries: list of the queries for table creation; it is read in
        reverse order but left unmodified
      conn: connection to the database
    """
    # reversed() walks the list backwards without reordering the caller's list.
    for sql in reversed(table_queries):
        type_name = get_type_and_name(sql)
        if type_name is None:
            logger.error("unrecognised entity type")
            continue
        entity_type, entity_name = type_name
        if entity_type == "ENUM":
            logger.warning(f"skipping {entity_name} as enum")
            continue
        statement = f"DROP {entity_type} IF EXISTS {entity_name};"
        with conn.cursor() as cursor:
            cursor.execute(Template(statement))
            logger.info(f"table {entity_name} removed.")
    conn.commit()

Drops all created tables from database.

Arguments:
  • logger: configured logger
  • table_queries: list of the queries for table creation
  • conn: connection to the database
def seed_static_data(logger: logging.Logger, conn: psycopg.Connection) -> None:
def seed_static_data(logger: Logger, conn: Connection) -> None:
    """Populate the lookup tables from the constants in database.db_constants.

    Rows for categories, allergens, badges and analytics graph types are
    inserted in that order; rows that conflict with existing ones are
    skipped. Everything is committed together. A psycopg error is logged
    and the whole batch is rolled back.

    Args:
      logger: configured logger
      conn: connection to the database
    """
    category_sql = (
        "INSERT INTO category "
        "(category_id, category_name, category_coefficient)"
        "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING"
    )
    allergen_sql = (
        "INSERT INTO allergens (allergen_id, allergen_name)"
        "VALUES (%s, %s) ON CONFLICT DO NOTHING"
    )
    badge_sql = (
        "INSERT INTO badges (badge_id, name, description)"
        "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING"
    )
    graph_type_sql = (
        "INSERT INTO analytics_graphs_types "
        "(graph_type_id, chart_type, graph_summary, x_axis_label, y_axis_label)"
        "VALUES (%s, %s, %s, %s, %s) ON CONFLICT DO NOTHING"
    )
    try:
        # Categories
        logger.info("inserting categories")
        for entry in CATEGORIES:
            row = (entry["cat_id"], entry["name"], entry["coefficient"])
            conn.execute(category_sql, row)

        # Allergens
        logger.info("inserting allergens")
        for allergen_id, allergen_name in ALLERGENS.items():
            conn.execute(allergen_sql, (allergen_id, allergen_name))

        # Badges
        logger.info("inserting badges")
        for entry in BADGES:
            row = (entry["badge_id"], entry["name"], entry["description"])
            conn.execute(badge_sql, row)

        # Analytics graph types
        logger.info("inserting analytics graphs types")
        for entry in ANALYTICS_GRAPHS_TYPES:
            row = (
                entry["graph_type_id"],
                entry["chart_type"],
                entry["graph_summary"],
                entry["x_axis_label"],
                entry["y_axis_label"],
            )
            conn.execute(graph_type_sql, row)

        conn.commit()
        logger.info("finished inserting data")
    except Error as e:
        logger.error(f"Error seeding data: {e}")
        conn.rollback()

Inserts categories, allergens, badges and analytics graph types into the db.

def main() -> None:
197def main() -> None:
198    """Entrypoint for database management script."""
199    dictConfig(LOGGING_CONFIG)
200    logger = getLogger("uvicorn.info")
201    table_queries = load_queries()
202    conn = get_db_connection(logger)
203    if not conn:
204        return
205    option = 0
206    while option not in {1, 2, 3, 4}:
207        message = (
208            "(1) Init all table, (2) Show all tables, (3) Drop all tables, (4) exit: "
209        )
210        option = int(input(message))
211    if input(f"confirm option {option} (y/N): ") not in {"y", "Y", "yes"}:
212        option = 0
213    match option:
214        case 1:
215            # Create all tables
216            create_all_tables(logger, table_queries, conn)
217            seed_static_data(logger, conn)
218        case 2:
219            # Show all tables
220            show_all_tables(logger, conn)
221        case 3:
222            # Drop all tables
223            drop_all_tables(logger, table_queries, conn)
224        case 4:
225            logger.info("shutting down: database connection closed")
226    conn.close()
227    logger.info("database connection closed.")

Entrypoint for database management script.