diff --git a/apiserver/src/core/db.py b/apiserver/src/core/db.py index 5067066..04723a5 100644 --- a/apiserver/src/core/db.py +++ b/apiserver/src/core/db.py @@ -1,14 +1,15 @@ import sqlite3 -from ..core.settings import get_data_path +from .settings import get_data_path, get_config def create_session(): path = get_data_path() / "./tam.db" - conn = sqlite3.connect(path) + conn = sqlite3.connect(path, timeout=10.0) cur = conn.cursor() return conn, cur def init_db(): + config = get_config() conn, cur = create_session() cur.execute("""CREATE TABLE IF NOT EXISTS auth (auth_key TEXT PRIMARY KEY, description TEXT)""") cur.execute("""CREATE TABLE IF NOT EXISTS prefixes ( @@ -53,8 +54,13 @@ def init_db(): ORDER BY b.prefix, t.last_name, t.first_name, t.phone_number, b.basket_id""") cur.execute("""CREATE VIEW IF NOT EXISTS counts AS SELECT prefix, COUNT(DISTINCT(CONCAT(first_name, last_name, phone_number))) AS unique_buyers, COUNT(*) AS total_buys + FROM tickets GROUP BY prefix UNION ALL - SELECT 'Total', COUNT(DISTINCT(CONCAT(first_name, last_name, phone_number))), COUNT(*)""") + SELECT 'Total', COUNT(DISTINCT(CONCAT(first_name, last_name, phone_number))), COUNT(*) + FROM tickets""") + if config["mode"] != "prod": + cur.execute("""REPLACE INTO auth VALUES ('2RO2T7GET9S7X64JUFN67OAV', 'Testing')""") conn.commit() + conn.close() print("DB initiated.") diff --git a/apiserver/src/core/models.py b/apiserver/src/core/models.py index 377d465..9c18082 100644 --- a/apiserver/src/core/models.py +++ b/apiserver/src/core/models.py @@ -7,6 +7,8 @@ from .db import create_session class RepoTemplate: def __init__(self): self.conn, self.cur = create_session() + def __del__(self): + self.conn.close() choose_from = string.ascii_uppercase + string.digits diff --git a/apiserver/src/data/models.py b/apiserver/src/data/models.py index b9f555d..e6f9768 100644 --- a/apiserver/src/data/models.py +++ b/apiserver/src/data/models.py @@ -15,6 +15,12 @@ class Ticket: phone_number: str = "" pref: str = "" +@dataclass +class Count: + prefix: str = "" + unique_buyers: int = 0 + total_buys: int = 0 + @dataclass class Basket: prefix: str = "" diff --git a/apiserver/src/data/repos.py b/apiserver/src/data/repos.py index ad6ab04..9f04be0 100644 --- a/apiserver/src/data/repos.py +++ b/apiserver/src/data/repos.py @@ -1,5 +1,5 @@ from ..core.models import RepoTemplate -from .models import Prefix, Ticket +from .models import Prefix, Ticket, Count class PrefixRepo(RepoTemplate): @@ -73,3 +73,11 @@ class TicketRepo(RepoTemplate): pref = EXCLUDED.pref""", (t.prefix, t.ticket_id, t.first_name, t.last_name, t.phone_number, t.pref)) self.conn.commit() return {"detail": "Tickets posted successfully."} + +class CountsRepo(RepoTemplate): + """Repo that controls the counts system.""" + + def get_counts(self): + self.cur.execute("SELECT * FROM counts") + results = self.cur.fetchall() + return [Count(*r) for r in results] diff --git a/apiserver/src/data/routers.py b/apiserver/src/data/routers.py index be9c0be..a156f34 100644 --- a/apiserver/src/data/routers.py +++ b/apiserver/src/data/routers.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Header, status, HTTPException -from .repos import PrefixRepo, TicketRepo +from .repos import PrefixRepo, TicketRepo, CountsRepo from .models import Prefix, Ticket from ..core.auth import AuthRepo @@ -61,3 +61,10 @@ def get_ticket_scope(prefix: str, s_id: str, tam_auth_key: str = Header("")) -> def post_tickets(ts: list[Ticket], tam_auth_key: str = Header("")): AuthRepo().verify_key(tam_auth_key) return TicketRepo().post_tickets(ts) + +counts_router = APIRouter(prefix="/api/counts") + +@counts_router.get("") +def get_counts(tam_auth_key: str = Header("")): + AuthRepo().verify_key(tam_auth_key) + return CountsRepo().get_counts() diff --git a/apiserver/src/routers.py b/apiserver/src/routers.py index 5c3b350..588d5f8 100644 --- a/apiserver/src/routers.py +++ b/apiserver/src/routers.py @@ -7,3 +7,4 @@ def append_routers(app: FastAPI): app.include_router(auth_router) app.include_router(routers.prefix_router) app.include_router(routers.ticket_router) + app.include_router(routers.counts_router) diff --git a/apiserver/tests/ticketstress.py b/apiserver/tests/ticketstress.py new file mode 100644 index 0000000..267ce5c --- /dev/null +++ b/apiserver/tests/ticketstress.py @@ -0,0 +1,67 @@ +import httpx +import threading +import random as r +import names as n +import json +from datetime import datetime +from time import sleep + +auth_key = "2RO2T7GET9S7X64JUFN67OAV" +auth_header = {"tam-auth-key": auth_key} + +prefix = "Spectacular" +start_num = 630001 +each_thread = 1000 +total = 6000 +unique_buyers = 200 + +batch_times = [] + +class Person: + def __init__(self): + first_digit = r.randint(0,1) + rest_digits = [r.randint(1,9) for _ in range(9)] + def nd(): + return rest_digits.pop(0) + self.first_name = n.get_first_name() + self.last_name = n.get_last_name() + self.phone_number = f"{first_digit}{nd()}{nd()}-{nd()}{nd()}{nd()}-{nd()}{nd()}{nd()}{nd()}" + self.pref = r.choice(["TEXT" for _ in range(3)] + ["CALL"]) + +buyer_pool = [Person().__dict__ for _ in range(unique_buyers)] + +def insert_batch(batch: list): + httpx.post("http://localhost:8000/api/tickets", headers=auth_header, json=batch) + now_time, batch_range = datetime.now().isoformat(), f"{batch[0]['ticket_id']} - {batch[-1]['ticket_id']}" + batch_len = len(batch) + print(f"Inserted a batch of {batch_len} at {now_time} along range of {batch_range}.") + batch_times.append({"Inserted time": now_time, "batch length": batch_len, "batch range": batch_range}) + batch.clear() + rand_sleep = r.randint(1,30) + print(f"Sleeping for {rand_sleep} ms.") + sleep(rand_sleep / 1000) + +def run_thread(begin_num: int): + batch = [] + for i in range(begin_num, begin_num + each_thread): + row = {"prefix": prefix, "ticket_id": i, **r.choice(buyer_pool)} + batch.append(row) + if len(batch) >= 20: + insert_batch(batch) + if len(batch) > 0: + insert_batch(batch) + +threads = [] + +for i in range(start_num, start_num + total, each_thread): + t = threading.Thread(target=run_thread, args=[i]) + threads.append(t) + +for t in threads: + t.start() + +for t in threads: + t.join() + +with open("stress_test.json", "w") as f: + json.dump(batch_times, f, indent=2)