diff --git a/app/database.py b/app/database.py index e937541..3108be1 100644 --- a/app/database.py +++ b/app/database.py @@ -4,6 +4,7 @@ from datetime import datetime, timedelta from typing import Optional import uuid import sqlalchemy as sa +from fastapi import Request from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy import Column, String, DateTime, Boolean, ForeignKey, select @@ -156,6 +157,6 @@ class ConversationManager: print(f"Cleaned up {len(conv_ids)} old conversations") -async def get_conversation_manager() -> ConversationManager: - """Dependency for FastAPI to get a ConversationManager instance.""" - return ConversationManager(settings.db_url) +async def get_conversation_manager(request: Request) -> ConversationManager: + """Dependency for FastAPI to get the singleton ConversationManager from app state.""" + return request.app.state.db_manager diff --git a/app/discord_bot.py b/app/discord_bot.py index 0504c08..83f0e69 100644 --- a/app/discord_bot.py +++ b/app/discord_bot.py @@ -3,10 +3,8 @@ import discord from discord import app_commands from discord.ext import commands -import asyncio import aiohttp from typing import Optional -import uuid from .config import settings @@ -51,6 +49,9 @@ class StratChatbotBot(commands.Bot): parent_message_id: Optional[str] = None, ) -> dict: """Send a request to the FastAPI chat endpoint.""" + if not self.session: + raise RuntimeError("Bot HTTP session not initialized") + payload = { "message": message, "user_id": user_id, @@ -76,6 +77,8 @@ bot = StratChatbotBot() @bot.event async def on_ready(): """Called when the bot is ready.""" + if not bot.user: + return print(f"🤖 Bot logged in as {bot.user} (ID: {bot.user.id})") print("Ready to answer Strat-O-Matic rules questions!") @@ -123,10 +126,9 @@ async def ask_command(interaction: discord.Interaction, question: str): inline=False, ) - # Add conversation ID for follow-ups + # Add conversation ID for follow-ups (full UUID so replies can be threaded) embed.set_footer( - text=f"Conversation: {result['conversation_id'][:8]}... " - f"| Reply to ask a follow-up" + text=f"conv:{result['conversation_id']} | Reply to ask a follow-up" ) await interaction.followup.send(embed=embed) @@ -149,7 +151,7 @@ async def on_message(message: discord.Message): return # Only handle replies to the bot's messages - if not message.reference: + if not message.reference or message.reference.message_id is None: return referenced = await message.channel.fetch_message(message.reference.message_id) @@ -168,31 +170,36 @@ async def on_message(message: discord.Message): return footer_text = embed.footer.text or "" - if "Conversation:" not in footer_text: + if "conv:" not in footer_text: await message.reply( "❓ Could not determine conversation. Use `/ask` to start fresh.", mention_author=True, ) return - # Extract conversation ID (rough parsing) + # Extract full conversation UUID from "conv: | ..." format try: - conv_id = footer_text.split("Conversation:")[1].split("...")[0].strip() - conversation_id = footer_text.split("Conversation:")[1].split(" ")[0].strip() - except: - conversation_id = None + conversation_id = footer_text.split("conv:")[1].split(" ")[0].strip() + except (IndexError, AttributeError): + await message.reply( + "❓ Could not parse conversation ID. Use `/ask` to start fresh.", + mention_author=True, + ) + return # Get parent message ID (the original answer message) parent_message_id = str(referenced.id) - # Forward to API - await message.reply("🔍 Looking into that follow-up...", mention_author=True) + # Send a loading placeholder and replace it with the real answer when ready + loading_msg = await message.reply( + "🔍 Looking into that follow-up...", mention_author=True + ) try: result = await bot.query_chat_api( message=message.content, user_id=str(message.author.id), - channel_id=str(message.channel_id), + channel_id=str(message.channel.id), conversation_id=conversation_id, parent_message_id=parent_message_id, ) @@ -213,19 +220,25 @@ async def on_message(message: discord.Message): if result.get("confidence", 0.0) < 0.4: response_embed.add_field( name="⚠️ Confidence", - value=f"Low - Human review requested", + value="Low - Human review requested", inline=False, ) - await message.reply(embed=response_embed, mention_author=True) + # Carry the conversation ID forward so further replies stay in the same thread + response_embed.set_footer( + text=f"conv:{result['conversation_id']} | Reply to ask a follow-up" + ) + + await loading_msg.edit(content=None, embed=response_embed) except Exception as e: - await message.reply( + await loading_msg.edit( + content=None, embed=discord.Embed( title="❌ Error", description=f"Failed to process follow-up: {str(e)}", color=discord.Color.red(), - ) + ), ) diff --git a/app/gitea.py b/app/gitea.py index 9d72622..a9522f2 100644 --- a/app/gitea.py +++ b/app/gitea.py @@ -19,17 +19,6 @@ class GiteaClient: "Content-Type": "application/json", "Accept": "application/json", } - self._client: Optional[httpx.AsyncClient] = None - - async def __aenter__(self): - """Async context manager entry.""" - self._client = httpx.AsyncClient(timeout=30.0) - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Async context manager exit.""" - if self._client: - await self._client.aclose() async def create_issue( self, @@ -39,12 +28,9 @@ class GiteaClient: assignee: Optional[str] = None, ) -> dict: """Create a new issue in the configured repository.""" - if not self._client: - raise RuntimeError("GiteaClient must be used as async context manager") - url = f"{self.base_url}/repos/{self.owner}/{self.repo}/issues" - payload = {"title": title, "body": body} + payload: dict = {"title": title, "body": body} if labels: payload["labels"] = labels @@ -52,7 +38,8 @@ class GiteaClient: if assignee: payload["assignee"] = assignee - response = await self._client.post(url, headers=self.headers, json=payload) + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post(url, headers=self.headers, json=payload) if response.status_code not in (200, 201): error_detail = response.text diff --git a/app/vector_store.py b/app/vector_store.py index 3fc85f7..ddb9b9e 100644 --- a/app/vector_store.py +++ b/app/vector_store.py @@ -85,7 +85,9 @@ class VectorStore: for i in range(len(results["documents"][0])): metadata = results["metadatas"][0][i] distance = results["distances"][0][i] - similarity = 1 - distance # Convert cosine distance to similarity + similarity = max( + 0.0, min(1.0, 1 - distance) + ) # Clamp to [0, 1]: cosine distance ranges 0–2 search_results.append( RuleSearchResult(