Add query_to_csv

This commit is contained in:
Cal Corum 2023-09-15 00:03:23 -05:00
parent bc37568c8e
commit 5fae0a30df
2 changed files with 32 additions and 4 deletions

View File

@ -2,9 +2,11 @@ import copy
import datetime import datetime
import logging import logging
import math import math
from typing import Literal from typing import Literal, List
from pandas import DataFrame
from peewee import * from peewee import *
from peewee import ModelSelect
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
db = SqliteDatabase( db = SqliteDatabase(
@ -40,6 +42,27 @@ WEEK_NUMS = {
} }
def model_csv_headers(this_obj, exclude=None) -> List:
if this_obj is None:
return ['None']
data = model_to_dict(this_obj, recurse=False, exclude=exclude)
return [x for x in data.keys()]
def model_to_csv(this_obj, exclude=None) -> List:
data = model_to_dict(this_obj, recurse=False, exclude=exclude)
return [x for x in data.values()]
def query_to_csv(all_items: ModelSelect, exclude=None):
data_list = [model_csv_headers(all_items[0], exclude=exclude)]
for x in all_items:
data_list.append(model_to_csv(x, exclude=exclude))
return DataFrame(data_list).to_csv(header=False, index=False)
def per_season_weeks(season: int, s_type: Literal['regular', 'post', 'total']): def per_season_weeks(season: int, s_type: Literal['regular', 'post', 'total']):
if season == 1: if season == 1:
if s_type == 'regular': if s_type == 'regular':

View File

@ -1,10 +1,10 @@
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query, Response
from typing import List, Optional, Literal from typing import List, Optional, Literal
import copy import copy
import logging import logging
import pydantic import pydantic
from ..db_engine import db, Team, Manager, Division, model_to_dict, chunked, fn from ..db_engine import db, Team, Manager, Division, model_to_dict, chunked, fn, query_to_csv
from ..dependencies import oauth2_scheme, valid_token, LOG_DATA from ..dependencies import oauth2_scheme, valid_token, LOG_DATA
logging.basicConfig( logging.basicConfig(
@ -43,7 +43,7 @@ class TeamList(pydantic.BaseModel):
async def get_teams( async def get_teams(
season: Optional[int] = None, owner_id: list = Query(default=None), manager_id: list = Query(default=None), season: Optional[int] = None, owner_id: list = Query(default=None), manager_id: list = Query(default=None),
team_abbrev: list = Query(default=None), active_only: Optional[bool] = False, team_abbrev: list = Query(default=None), active_only: Optional[bool] = False,
short_output: Optional[bool] = False): short_output: Optional[bool] = False, csv: Optional[bool] = False):
if season is not None: if season is not None:
all_teams = Team.select_season(season) all_teams = Team.select_season(season)
else: else:
@ -64,6 +64,11 @@ async def get_teams(
~(Team.abbrev.endswith('IL')) & ~(Team.abbrev.endswith('MiL')) ~(Team.abbrev.endswith('IL')) & ~(Team.abbrev.endswith('MiL'))
) )
if csv:
return_val = query_to_csv(all_teams, exclude=[Team.division_legacy, Team.mascot, Team.gsheet])
db.close()
return Response(content=return_val, media_type='text/csv')
return_teams = { return_teams = {
'count': all_teams.count(), 'count': all_teams.count(),
'teams': [model_to_dict(x, recurse=not short_output) for x in all_teams] 'teams': [model_to_dict(x, recurse=not short_output) for x in all_teams]