144 lines
5.5 KiB
Python
144 lines
5.5 KiB
Python
from fastapi import FastAPI, Request, Header, HTTPException
|
|
from contextlib import asynccontextmanager
|
|
from starlette.responses import Response
|
|
from jose import jwt
|
|
from datetime import datetime, timedelta, timezone
|
|
from cachetools import TTLCache
|
|
from logging.handlers import RotatingFileHandler
|
|
import asyncpg
|
|
import httpx
|
|
import logging
|
|
import os
|
|
|
|
# Module logger: INFO and above go to a size-rotated file under ./logs.
logger = logging.getLogger('apiproxy')
logger.setLevel(logging.INFO)

# RotatingFileHandler opens its file as soon as it is constructed, so a missing
# `logs/` directory would make importing this module fail with FileNotFoundError.
# The path is relative to the current working directory, as before.
os.makedirs('logs', exist_ok=True)
handler = RotatingFileHandler(
    filename='logs/apiproxy.log',
    # encoding='utf-8',
    maxBytes=32 * 1024 * 1024,  # rotate once the file reaches 32 MiB
    backupCount=5,  # keep apiproxy.log plus up to 5 rotated copies (.1 ... .5)
)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)

app = FastAPI()
logger.info('\n* * * * * * * * * * * *\nInitializing Paper Dynasty API Proxy\n* * * * * * * * * * * *')
|
|
|
|
# Env config
# NOTE: the fallback JWT secret is visible in this source file; tokens signed
# with it can be forged by anyone who has read the code.
JWT_SECRET = os.getenv("JWT_SECRET", "super-secret")
JWT_ALGORITHM = "HS256"
POSTGREST_URL = os.getenv("POSTGREST_URL", "http://postgrest:3000")
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:pass@localhost:5432/mydb")
# Any casing of "true" enables production mode; anything else (or unset) is dev.
PRODUCTION = os.getenv("PRODUCTION", '').lower() == 'true'

if PRODUCTION and JWT_SECRET in ("", "super-secret"):
    logger.warning('JWT_SECRET is not set; PostgREST tokens are being signed with an insecure default secret')

# TTL cache: 1000 keys, 10 min TTL.
# Maps API key -> {"user_id", "role"}. Because the cache is consulted before the
# database, a key deactivated in the DB keeps working until its entry expires.
api_key_cache = TTLCache(maxsize=1000, ttl=600)
logger.info('api_key_cache is instantiated')
|
|
|
|
# Postgres connection pool
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the asyncpg connection pool for the lifetime of the application.

    In production a pool is created from ``DATABASE_URL`` on startup and stored
    on ``app.state.db``, where ``proxy_postgrest`` reads it. The pool is closed
    on shutdown. Outside production no pool is created, but the context manager
    still yields so that application startup succeeds.
    """
    logger.info(f'entering fastapi lifespan function / PRODUCTION: {PRODUCTION}')
    pool = None
    if PRODUCTION:
        pool = await asyncpg.create_pool(DATABASE_URL)
        app.state.db = pool
        logger.info('database pool is instantiated')
    try:
        yield
    finally:
        # Only close what this lifespan opened; in dev mode there is no pool.
        if pool is not None:
            await pool.close()


# `app` is constructed as a bare `FastAPI()` (no `lifespan=` argument). Without
# this assignment the lifespan above never runs and `app.state.db` is never set.
# Starlette's router enters `lifespan_context(app)` on startup and exits it on
# shutdown.
app.router.lifespan_context = lifespan
|
|
|
|
async def fetch_user_from_db(db, api_key: str):
    """Return the identity behind an active API key, or None if there is none.

    Runs a parameterized query against the ``api_keys`` table via
    ``db.fetchrow`` (in this file, the asyncpg pool stored on ``app.state.db``)
    and keeps only the ``user_id`` and ``role`` columns of the matching row.
    """
    # consider making discord ID the api key and pull from teams table
    sql = "SELECT user_id, role FROM api_keys WHERE key = $1 AND active = true"
    record = await db.fetchrow(sql, api_key)
    if not record:
        return None
    return {column: record[column] for column in ("user_id", "role")}
|
|
|
|
def fetch_user_in_dev(api_key: str):
    """Resolve an API key against a tiny hard-coded table used outside production.

    Returns a fresh ``{"user_id": ..., "role": ...}`` dict for a known key and
    None for any other key.
    """
    # key -> (user_id, role); rebuilt on every call so callers never share a dict
    dev_accounts = {
        "key-alice": ("alice", "authenticated"),
        "key-bob": ("bob", "user"),
    }
    entry = dev_accounts.get(api_key)
    if entry is None:
        return None
    user_id, role = entry
    return {"user_id": user_id, "role": role}
|
|
|
|
def generate_jwt(user_id: str, role: str, exp_seconds=3600):
    """Sign a short-lived JWT for *user_id* with ``JWT_SECRET`` / ``JWT_ALGORITHM``.

    Args:
        user_id: placed in the ``sub`` claim.
        role: placed in the ``role`` claim.
        exp_seconds: token lifetime in seconds (``exp`` = ``iat`` + this).

    Returns:
        The encoded token string; the proxy sends it to PostgREST as
        ``Authorization: Bearer <token>``.
    """
    # Read the clock once so `iat` and `exp` are derived from the same instant.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": user_id,
        "role": role,
        "exp": issued_at + timedelta(seconds=exp_seconds),
        "iat": issued_at,
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
|
|
|
logger.info('attaching proxy_postgrest function to fastapi app')

# Hop-by-hop headers (RFC 9110 section 7.6.1) describe a single connection and
# must not be relayed by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset({
    'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
    'te', 'trailer', 'trailers', 'transfer-encoding', 'upgrade',
})

# Client request headers that are never sent to PostgREST. The caller's API key
# and any caller-supplied Authorization must not reach PostgREST: the only
# credential it should see is the JWT signed by this proxy. httpx recomputes
# Host and Content-Length for the upstream request.
_DROP_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {'host', 'content-length', 'x-api-key', 'authorization'}

# Upstream response headers that no longer describe `response.content`: httpx
# has already decoded any Content-Encoding, and Starlette computes its own
# Content-Length for the body it sends back.
_DROP_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {'content-encoding', 'content-length'}

# Header values replaced with '***' before headers are written to the log.
_SENSITIVE_HEADERS = frozenset({'authorization', 'x-api-key', 'cookie'})


def _mask_key(api_key):
    """Return a log-safe form of an API key: only its last 4 characters (or nothing if short)."""
    if not api_key:
        return api_key
    return f'***{api_key[-4:]}' if len(api_key) > 8 else '***'


def _redact(headers):
    """Return a plain-dict copy of *headers* with credential-bearing values masked."""
    return {
        name: ('***' if name.lower() in _SENSITIVE_HEADERS else value)
        for name, value in headers.items()
    }


async def _resolve_user(api_key, log_prefix):
    """Map *api_key* to ``{"user_id", "role"}``: cache first, then DB (prod) or dev table.

    Successful lookups are stored in ``api_key_cache``.

    Raises:
        HTTPException: 401 for an unknown/inactive key; 503 if the production
            database pool has not been initialised on ``app.state``.
    """
    # Step 1: Check cache
    user_info = api_key_cache.get(api_key)
    logger.info(f'{log_prefix} - cached user: {user_info}')
    if user_info:
        return user_info

    # Step 2: Cache miss -> look up in DB
    logger.info(f'{log_prefix} - in prod: {PRODUCTION}')
    if PRODUCTION:
        logger.info(f'{log_prefix} - looking up user in prod db')
        db = getattr(app.state, 'db', None)
        if db is None:
            logger.error(f'{log_prefix} - database pool is not initialized')
            raise HTTPException(status_code=503, detail="Database unavailable")
        user_info = await fetch_user_from_db(db, api_key)
    else:
        logger.info(f'{log_prefix} - looking up user in fake db')
        user_info = fetch_user_in_dev(api_key)
    logger.info(f'{log_prefix} - user_info: {user_info}')
    if not user_info:
        raise HTTPException(status_code=401, detail="Invalid or inactive API key")

    # Step 3: Cache it
    logger.info(f'{log_prefix} - caching {_mask_key(api_key)} for {user_info}')
    api_key_cache[api_key] = user_info
    return user_info


@app.api_route("/api/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
async def proxy_postgrest(
    path: str,
    request: Request,
    x_api_key: str = Header(default=None)
):
    """Authenticate the caller by its ``X-API-Key`` header and relay the request to PostgREST.

    The key is resolved to a user, a JWT is signed for that user, and the
    method, body, query string and (filtered) headers are forwarded to
    ``POSTGREST_URL/<path>`` with the JWT as the Bearer token. PostgREST's
    status, body and (filtered) headers are returned to the caller.

    Raises:
        HTTPException: 401 if the key is missing, unknown or inactive; 503 if
            the production DB pool is missing; 502 if PostgREST cannot be reached.
    """
    log_prefix = f'{getattr(getattr(request, "client", None), "host", None)} - {datetime.now().timestamp()}'
    logger.info(f'{log_prefix} - an API call was made to: /{path} with api key: {_mask_key(x_api_key)}')
    if not x_api_key:
        raise HTTPException(status_code=401, detail="Unauthorized access")

    user_info = await _resolve_user(x_api_key, log_prefix)

    # Step 4: Sign JWT
    logger.info(f'{log_prefix} - generating jwt for postgrest')
    token = generate_jwt(user_info["user_id"], user_info["role"])

    # Step 5: Forward request to PostgREST
    method = request.method
    body = await request.body()
    logger.info(f'{log_prefix} - incoming headers: {_redact(request.headers)}')
    headers = {
        name: value
        for name, value in request.headers.items()
        if name.lower() not in _DROP_REQUEST_HEADERS
    }
    headers["Authorization"] = f"Bearer {token}"
    # multi_items() keeps repeated keys (e.g. ?id=gt.1&id=lt.9). Passing the
    # QueryParams mapping itself would collapse each key to a single value.
    params = request.query_params.multi_items()
    url = f"{POSTGREST_URL}/{path}"

    # TODO: proxy is currently not honoring Accept: text/csv -- re-check now that
    # client headers (including Accept) are forwarded and the upstream
    # Content-Length / Content-Encoding are no longer copied onto the response.
    logger.info(f'{log_prefix} - sending request to postgrest for {user_info}:\nMethod: {method}\nURL: {url}\nHeaders: {_redact(headers)}\nBody: {body}\nParams: {request.query_params}')
    try:
        async with httpx.AsyncClient() as client:
            response = await client.request(
                method=method,
                url=url,
                headers=headers,
                content=body,
                params=params,
            )
    except httpx.RequestError as exc:
        # Connection failures and timeouts: report a gateway error instead of a bare 500.
        logger.error(f'{log_prefix} - request to postgrest failed: {exc!r}')
        raise HTTPException(status_code=502, detail="Upstream request failed") from exc

    logger.info(f'{log_prefix} - {user_info} / Response Code: {response.status_code}')
    # PostgREST answers successful writes with 201/204; only non-2xx is noteworthy.
    if not response.is_success:
        logger.warning(f'{log_prefix} - Response Content: {response.content}')
    response_headers = {
        name: value
        for name, value in response.headers.items()
        if name.lower() not in _DROP_RESPONSE_HEADERS
    }
    return Response(
        content=response.content,
        status_code=response.status_code,
        headers=response_headers,
    )
|
|
|
|
# Marks that the whole module has been executed (all routes registered).
logger.info('end of main.py')
|