Initial commit: Voice server with Piper TTS
A local HTTP service that accepts text via POST and speaks it through system speakers using Piper TTS neural voice synthesis. Features: - POST /notify - Queue text for TTS playback - GET /health - Health check with TTS/audio/queue status - GET /voices - List installed voice models - Async queue processing (no overlapping audio) - Non-blocking audio via sounddevice - 73 tests covering API contract Tech stack: - FastAPI + Uvicorn - Piper TTS (neural voices, offline) - sounddevice (PortAudio) - Pydantic for validation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
commit
a34aec06f1
22
.env.example
Normal file
22
.env.example
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# Voice Server Configuration
|
||||||
|
# Copy this file to .env and modify as needed
|
||||||
|
|
||||||
|
# Server Settings
|
||||||
|
HOST=0.0.0.0
|
||||||
|
PORT=8888
|
||||||
|
|
||||||
|
# TTS Settings
|
||||||
|
MODEL_DIR=./models
|
||||||
|
DEFAULT_VOICE=en_US-lessac-medium
|
||||||
|
DEFAULT_RATE=170
|
||||||
|
|
||||||
|
# Queue Settings
|
||||||
|
QUEUE_MAX_SIZE=50
|
||||||
|
REQUEST_TIMEOUT_SECONDS=60
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
LOG_FILE=voice-server.log
|
||||||
|
|
||||||
|
# Debug: set VOICE_ENABLED=false to disable TTS playback (e.g. for testing)
|
||||||
|
# VOICE_ENABLED=true
|
||||||
58
.gitignore
vendored
Normal file
58
.gitignore
vendored
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
logs/
|
||||||
|
|
||||||
|
# Voice models (large files)
|
||||||
|
models/*.onnx
|
||||||
|
models/*.onnx.json
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
.pytest_cache/
|
||||||
|
.tox/
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# uv
|
||||||
|
uv.lock
|
||||||
1012
PROJECT_ROADMAP.json
Normal file
1012
PROJECT_ROADMAP.json
Normal file
File diff suppressed because it is too large
Load Diff
41
README.md
Normal file
41
README.md
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# Voice Server
|
||||||
|
|
||||||
|
Local HTTP service for text-to-speech playback using Piper TTS.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- HTTP POST endpoint for text-to-speech requests
|
||||||
|
- High-quality neural TTS using Piper
|
||||||
|
- Non-blocking audio playback with sounddevice
|
||||||
|
- Async request queue: accepts concurrent requests and plays them one at a time (no overlapping audio)
|
||||||
|
- Automatic OpenAPI documentation
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install dependencies
|
||||||
|
uv pip install -e ".[dev]"
|
||||||
|
|
||||||
|
# Run server
|
||||||
|
uvicorn app.main:app --reload
|
||||||
|
|
||||||
|
# Test endpoint
|
||||||
|
curl -X POST http://localhost:8000/notify \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"message": "Hello, world!"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Endpoints
|
||||||
|
|
||||||
|
- `POST /notify` - Submit text for TTS playback
|
||||||
|
- `GET /health` - Health check endpoint
|
||||||
|
- `GET /voices` - List available voice models
|
||||||
|
- `GET /docs` - OpenAPI documentation
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
See `.env.example` for configuration options.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT
|
||||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
192
app/audio_player.py
Normal file
192
app/audio_player.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
Audio playback module for voice-server.
|
||||||
|
|
||||||
|
Provides non-blocking audio playback using sounddevice.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPlayer(Protocol):
    """Structural interface that an audio playback backend must satisfy.

    Any object exposing these four methods with compatible signatures can be
    used where an ``AudioPlayer`` is expected; no inheritance is required.
    """

    def play(self, audio_data: np.ndarray, sample_rate: int) -> None:
        """Start playing ``audio_data`` at ``sample_rate`` Hz without blocking."""
        ...

    def is_playing(self) -> bool:
        """Report whether playback is currently in progress."""
        ...

    def stop(self) -> None:
        """Halt whatever is playing right now."""
        ...

    async def wait_async(self) -> None:
        """Suspend the calling coroutine until playback has finished."""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class SounddevicePlayer:
|
||||||
|
"""
|
||||||
|
Audio player implementation using sounddevice.
|
||||||
|
|
||||||
|
Provides non-blocking playback with async wait support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, default_sample_rate: int = 22050, retry_attempts: int = 3):
|
||||||
|
"""
|
||||||
|
Initialize the audio player.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default_sample_rate: Default sample rate if not specified in play()
|
||||||
|
retry_attempts: Number of retry attempts on playback failure
|
||||||
|
"""
|
||||||
|
self.default_sample_rate = default_sample_rate
|
||||||
|
self.retry_attempts = retry_attempts
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
# Lazy import sounddevice to defer PortAudio initialization
|
||||||
|
self._sd = None
|
||||||
|
|
||||||
|
def _ensure_initialized(self):
|
||||||
|
"""Ensure sounddevice is imported and initialized."""
|
||||||
|
if self._sd is None:
|
||||||
|
try:
|
||||||
|
import sounddevice as sd
|
||||||
|
|
||||||
|
self._sd = sd
|
||||||
|
self._initialized = True
|
||||||
|
logger.info("SounddevicePlayer initialized successfully")
|
||||||
|
except OSError as e:
|
||||||
|
logger.error(f"Failed to initialize sounddevice: {e}")
|
||||||
|
raise RuntimeError(f"Audio system unavailable: {e}") from e
|
||||||
|
|
||||||
|
def play(self, audio_data: np.ndarray, sample_rate: int | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Play audio data (non-blocking).
|
||||||
|
|
||||||
|
The audio plays in a background thread. Use is_playing() to check status
|
||||||
|
or wait_async() to wait for completion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_data: NumPy array of audio samples (float32 or int16)
|
||||||
|
sample_rate: Sample rate in Hz (uses default if not specified)
|
||||||
|
"""
|
||||||
|
self._ensure_initialized()
|
||||||
|
|
||||||
|
if len(audio_data) == 0:
|
||||||
|
logger.debug("Skipping playback of empty audio")
|
||||||
|
return
|
||||||
|
|
||||||
|
rate = sample_rate or self.default_sample_rate
|
||||||
|
|
||||||
|
# Stop any currently playing audio
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
for attempt in range(self.retry_attempts):
|
||||||
|
try:
|
||||||
|
# Play audio - returns immediately, audio plays in background
|
||||||
|
self._sd.play(audio_data, rate)
|
||||||
|
logger.debug(f"Started playback: {len(audio_data)} samples at {rate}Hz")
|
||||||
|
return
|
||||||
|
except self._sd.PortAudioError as e:
|
||||||
|
logger.warning(f"Playback attempt {attempt + 1} failed: {e}")
|
||||||
|
if attempt < self.retry_attempts - 1:
|
||||||
|
time.sleep(0.5)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Audio playback failed after {self.retry_attempts} attempts: {e}")
|
||||||
|
|
||||||
|
def is_playing(self) -> bool:
|
||||||
|
"""Check if audio is currently playing."""
|
||||||
|
if self._sd is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
stream = self._sd.get_stream()
|
||||||
|
return stream is not None and stream.active
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop current playback."""
|
||||||
|
if self._sd is not None:
|
||||||
|
try:
|
||||||
|
self._sd.stop()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error stopping playback: {e}")
|
||||||
|
|
||||||
|
def wait(self) -> None:
|
||||||
|
"""Block until current playback completes."""
|
||||||
|
if self._sd is not None:
|
||||||
|
try:
|
||||||
|
self._sd.wait()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error waiting for playback: {e}")
|
||||||
|
|
||||||
|
async def wait_async(self, poll_interval: float = 0.05) -> None:
|
||||||
|
"""
|
||||||
|
Wait asynchronously for playback to complete.
|
||||||
|
|
||||||
|
Uses polling to avoid blocking the event loop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
poll_interval: How often to check playback status (seconds)
|
||||||
|
"""
|
||||||
|
while self.is_playing():
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
|
|
||||||
|
def get_diagnostics(self) -> dict:
|
||||||
|
"""
|
||||||
|
Get audio system diagnostics for health checks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with audio device information and status
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._ensure_initialized()
|
||||||
|
|
||||||
|
devices = self._sd.query_devices()
|
||||||
|
output_devices = [d for d in devices if d["max_output_channels"] > 0]
|
||||||
|
|
||||||
|
if not output_devices:
|
||||||
|
return {
|
||||||
|
"status": "unavailable",
|
||||||
|
"error": "No audio output devices found",
|
||||||
|
}
|
||||||
|
|
||||||
|
default_output = self._sd.query_devices(kind="output")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "available",
|
||||||
|
"device_count": len(output_devices),
|
||||||
|
"default_output": default_output["name"],
|
||||||
|
"default_sample_rate": int(default_output["default_samplerate"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "unavailable",
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
def health_check(self) -> dict:
|
||||||
|
"""
|
||||||
|
Perform a health check on the audio system.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and any error messages
|
||||||
|
"""
|
||||||
|
diagnostics = self.get_diagnostics()
|
||||||
|
|
||||||
|
if diagnostics["status"] == "available":
|
||||||
|
return {"status": "healthy", "details": diagnostics}
|
||||||
|
else:
|
||||||
|
return {"status": "unhealthy", "error": diagnostics.get("error", "Unknown error")}
|
||||||
98
app/config.py
Normal file
98
app/config.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
Configuration management for voice-server.
|
||||||
|
|
||||||
|
Loads configuration from environment variables with sensible defaults.
|
||||||
|
Uses pydantic-settings for type-safe configuration loading and validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
from pydantic import Field, field_validator
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
# Valid log levels
|
||||||
|
LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
    """
    Typed application configuration.

    Values are read from environment variables or from a ``.env`` file
    (UTF-8). Variable names are matched case-insensitively, unknown keys are
    ignored, and every field has a default.
    """

    model_config = SettingsConfigDict(
        extra="ignore",
        case_sensitive=False,
        env_file_encoding="utf-8",
        env_file=".env",
    )

    # --- Server ---
    host: str = Field(default="0.0.0.0", description="Host to bind to")
    port: int = Field(default=8888, ge=1, le=65535, description="Port to listen on")

    # --- TTS ---
    model_dir: str = Field(default="./models", description="Directory containing voice models")
    default_voice: str = Field(default="en_US-lessac-medium", description="Default voice model")
    default_rate: int = Field(default=170, ge=50, le=400, description="Default speech rate (WPM)")

    # --- Queue ---
    queue_max_size: int = Field(default=50, gt=0, description="Maximum TTS queue size")
    request_timeout_seconds: int = Field(default=60, gt=0, description="Request processing timeout")

    # --- Logging ---
    log_level: LogLevel = Field(default="INFO", description="Logging level")
    log_file: str | None = Field(default=None, description="Log file path (None for stdout only)")

    # --- Debug ---
    voice_enabled: bool = Field(default=True, description="Enable/disable TTS playback")

    @field_validator("log_level", mode="before")
    @classmethod
    def uppercase_log_level(cls, v: str) -> str:
        """Upper-case string input before validation so e.g. "info" is accepted."""
        return v.upper() if isinstance(v, str) else v
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
def get_settings() -> Settings:
    """
    Return the shared Settings instance, building it on the first call.

    The function takes no arguments, so the cache holds a single object that
    every later call receives. Call ``get_settings.cache_clear()`` to force
    the configuration to be loaded again.
    """
    settings = Settings()
    return settings
|
||||||
140
app/main.py
Normal file
140
app/main.py
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
"""
|
||||||
|
Voice Server - Local HTTP service for text-to-speech playback.
|
||||||
|
|
||||||
|
This module provides the FastAPI application with endpoints for:
|
||||||
|
- POST /notify: Submit text for TTS playback
|
||||||
|
- GET /health: Health check endpoint
|
||||||
|
- GET /voices: List available voice models
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.audio_player import SounddevicePlayer
|
||||||
|
from app.config import get_settings
|
||||||
|
from app.queue_manager import TTSQueueManager
|
||||||
|
from app.tts_engine import PiperTTSEngine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Track server start time for uptime calculation
|
||||||
|
_start_time: float = 0.0
|
||||||
|
|
||||||
|
# Global instances (initialized in lifespan)
|
||||||
|
tts_engine: PiperTTSEngine | None = None
|
||||||
|
audio_player: SounddevicePlayer | None = None
|
||||||
|
queue_manager: TTSQueueManager | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan handler.

    Handles startup and shutdown events:
    - Startup: record the start time, then create the TTS engine, the audio
      player and the queue manager (whose processor task is started).
    - Shutdown: stop the queue manager, then stop audio playback.

    The shutdown steps run in a ``finally`` clause so they still execute when
    an exception is raised into this context manager while the app is running.
    """
    global _start_time, tts_engine, audio_player, queue_manager

    settings = get_settings()
    _start_time = time.time()

    # Initialize TTS engine
    logger.info("Initializing TTS engine...")
    tts_engine = PiperTTSEngine(
        model_dir=settings.model_dir,
        default_voice=settings.default_voice,
    )

    # Initialize audio player
    logger.info("Initializing audio player...")
    audio_player = SounddevicePlayer(
        default_sample_rate=tts_engine.get_sample_rate(),
    )

    # Initialize and start queue manager
    logger.info("Starting queue manager...")
    queue_manager = TTSQueueManager(
        tts_engine=tts_engine,
        audio_player=audio_player,
        max_size=settings.queue_max_size,
        request_timeout=settings.request_timeout_seconds,
    )
    await queue_manager.start()

    logger.info("Voice server started successfully")

    try:
        yield
    finally:
        # Shutdown cleanup
        logger.info("Shutting down voice server...")
        if queue_manager:
            await queue_manager.stop()
        if audio_player:
            audio_player.stop()
        logger.info("Voice server stopped")
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
    """
    Create and configure the FastAPI application.

    Returns:
        A FastAPI instance wired with the lifespan handler, permissive CORS
        middleware, and the routes from ``app.routes``.
    """
    # Load settings now so an invalid configuration fails when the app is
    # built; the result is cached by get_settings() for later callers.
    get_settings()

    app = FastAPI(
        title="Voice Server",
        description="Local HTTP service for text-to-speech playback using Piper TTS",
        version="1.0.0",
        lifespan=lifespan,
    )

    # Configure CORS.
    # Credentials are disabled because wildcard origins/methods/headers are
    # not valid together with allow_credentials=True (see the FastAPI CORS
    # documentation). To support credentialed browser requests, list the
    # allowed origins explicitly instead of "*".
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Register routes (imported here rather than at module top;
    # presumably to avoid a circular import with app.routes — TODO confirm)
    from app.routes import router

    app.include_router(router)

    return app
|
||||||
|
|
||||||
|
|
||||||
|
def get_uptime_seconds() -> int:
    """
    Return whole seconds elapsed since the recorded start time.

    While ``_start_time`` still holds its initial 0.0 value (the lifespan
    handler has not run yet), 0 is returned.
    """
    started = _start_time
    return int(time.time() - started) if started != 0.0 else 0
|
||||||
|
|
||||||
|
|
||||||
|
# Create the application instance
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
|
|
||||||
|
def run():
    """
    Launch the application under uvicorn (CLI entry point).

    Host and port come from the application settings; auto-reload is enabled.
    """
    import uvicorn

    cfg = get_settings()
    uvicorn.run("app.main:app", host=cfg.host, port=cfg.port, reload=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
162
app/models.py
Normal file
162
app/models.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
"""
|
||||||
|
Pydantic models for voice-server request/response validation.
|
||||||
|
|
||||||
|
This module defines the API contract for all endpoints:
|
||||||
|
- NotifyRequest/NotifyResponse: POST /notify
|
||||||
|
- HealthResponse: GET /health
|
||||||
|
- VoicesResponse: GET /voices
|
||||||
|
- ErrorResponse: Error responses
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class NotifyRequest(BaseModel):
    """
    Body of a POST /notify request.

    Carries the text to speak plus optional voice, rate and enable/disable
    parameters. The message is stripped of surrounding whitespace before
    validation, so a whitespace-only message fails the minimum-length check.
    """

    message: str = Field(
        min_length=1,
        max_length=10000,
        description="Text to convert to speech (1-10000 characters)",
    )
    voice: str = Field(
        default="en_US-lessac-medium",
        pattern=r"^[\w-]+$",
        description="Piper voice model name",
    )
    rate: int = Field(
        default=170,
        ge=50,
        le=400,
        description="Speech rate in words per minute (50-400)",
    )
    voice_enabled: bool = Field(
        default=True,
        description="Enable/disable TTS playback (for debugging)",
    )

    @field_validator("message", mode="before")
    @classmethod
    def strip_message_whitespace(cls, v: str) -> str:
        """Trim surrounding whitespace from string input; pass other types through."""
        return v.strip() if isinstance(v, str) else v
|
||||||
|
|
||||||
|
|
||||||
|
class NotifyResponse(BaseModel):
    """
    Body returned by a successful POST /notify.

    Describes the request that was accepted into the TTS queue.
    """

    status: str = Field(description="Request status (e.g., 'queued')")
    message_length: int = Field(description="Length of the message in characters")
    queue_position: int = Field(description="Position in the TTS queue")
    voice_model: str = Field(description="Voice model being used")
    estimated_duration: float | None = Field(
        default=None,
        description="Estimated playback duration in seconds",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class QueueStatus(BaseModel):
    """Snapshot of the TTS queue reported by the health endpoint."""

    size: int = Field(description="Current number of items in queue")
    capacity: int = Field(description="Maximum queue capacity")
    utilization: float = Field(description="Queue utilization percentage")
|
||||||
|
|
||||||
|
|
||||||
|
class HealthResponse(BaseModel):
    """
    Body returned by GET /health.

    Combines overall status with uptime, queue, TTS engine and audio output
    information. The optional fields default to None, and ``timestamp``
    defaults to the current UTC time when the model is created.
    """

    status: str = Field(description="Overall health status ('healthy' or 'unhealthy')")
    uptime_seconds: int = Field(description="Server uptime in seconds")
    queue: QueueStatus = Field(description="Queue status information")
    tts_engine: str = Field(description="TTS engine name")
    audio_output: str = Field(description="Audio output status")
    voice_models_loaded: list[str] | None = Field(
        default=None,
        description="List of loaded voice models",
    )
    total_requests: int | None = Field(default=None, description="Total requests processed")
    failed_requests: int | None = Field(default=None, description="Number of failed requests")
    errors: list[str] | None = Field(
        default=None,
        description="List of error messages if unhealthy",
    )
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="Response timestamp",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponse(BaseModel):
    """
    Uniform body for 4xx and 5xx error responses.

    ``timestamp`` defaults to the current UTC time when the model is created;
    ``queue_size`` is optional and described as relevant to queue_full errors.
    """

    error: str = Field(description="Error type identifier")
    detail: str = Field(description="Human-readable error description")
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="Error timestamp",
    )
    queue_size: int | None = Field(
        default=None,
        description="Current queue size (for queue_full errors)",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceInfo(BaseModel):
    """Description of one voice model as listed by the voices endpoint."""

    name: str = Field(description="Voice model name")
    language: str = Field(description="Language code (e.g., 'en_US')")
    quality: str = Field(description="Quality level (low, medium, high)")
    size_mb: float = Field(description="Model size in megabytes")
    installed: bool = Field(description="Whether the model is installed locally")
|
||||||
|
|
||||||
|
|
||||||
|
class VoicesResponse(BaseModel):
    """
    Body returned by GET /voices.

    Holds the available voice models together with the name of the default.
    """

    voices: list[VoiceInfo] = Field(description="List of available voices")
    default_voice: str = Field(description="Default voice model name")
|
||||||
236
app/queue_manager.py
Normal file
236
app/queue_manager.py
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
"""
|
||||||
|
TTS Queue Manager for voice-server.
|
||||||
|
|
||||||
|
Manages an async queue of TTS requests and processes them sequentially.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QueueFullError(Exception):
    """Signals that a request could not be enqueued because the TTS queue is full."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TTSRequest:
|
||||||
|
"""A TTS request in the queue."""
|
||||||
|
|
||||||
|
message: str
|
||||||
|
voice: str
|
||||||
|
rate: int
|
||||||
|
voice_enabled: bool
|
||||||
|
timestamp: float = field(default_factory=time.time)
|
||||||
|
request_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class QueueStats:
    """
    Running counters kept by the queue processor.

    Attributes:
        processed: Requests handled without error.
        errors: Requests (or processor iterations) that failed.
        total_audio_seconds: Cumulative duration of audio played.
    """

    processed: int = field(default=0)
    errors: int = field(default=0)
    total_audio_seconds: float = field(default=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
class TTSQueueManager:
|
||||||
|
"""
|
||||||
|
Manages the TTS request queue and processes requests sequentially.
|
||||||
|
|
||||||
|
Ensures audio doesn't overlap by processing one request at a time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tts_engine: Any,
|
||||||
|
audio_player: Any,
|
||||||
|
max_size: int = 50,
|
||||||
|
request_timeout: float = 60.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the queue manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tts_engine: TTS engine instance for synthesis
|
||||||
|
audio_player: Audio player instance for playback
|
||||||
|
max_size: Maximum queue size
|
||||||
|
request_timeout: Timeout for processing each request (seconds)
|
||||||
|
"""
|
||||||
|
self.tts_engine = tts_engine
|
||||||
|
self.audio_player = audio_player
|
||||||
|
self.max_size = max_size
|
||||||
|
self.request_timeout = request_timeout
|
||||||
|
|
||||||
|
self._queue: asyncio.Queue[TTSRequest] = asyncio.Queue(maxsize=max_size)
|
||||||
|
self._stats = QueueStats()
|
||||||
|
self._running = False
|
||||||
|
self._processor_task: asyncio.Task | None = None
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Start the queue processor background task."""
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._running = True
|
||||||
|
self._processor_task = asyncio.create_task(self._process_queue())
|
||||||
|
logger.info("TTS queue processor started")
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop the queue processor and wait for current item to complete."""
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
if self._processor_task:
|
||||||
|
# Cancel the task
|
||||||
|
self._processor_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._processor_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._processor_task = None
|
||||||
|
|
||||||
|
# Stop any playing audio
|
||||||
|
self.audio_player.stop()
|
||||||
|
logger.info("TTS queue processor stopped")
|
||||||
|
|
||||||
|
async def enqueue(self, request: TTSRequest) -> int:
|
||||||
|
"""
|
||||||
|
Add a TTS request to the queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The TTS request to queue
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Queue position (1-indexed)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QueueFullError: If the queue is full
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Use a short timeout to avoid blocking
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._queue.put(request),
|
||||||
|
timeout=1.0,
|
||||||
|
)
|
||||||
|
position = self._queue.qsize()
|
||||||
|
logger.debug(f"Enqueued request: {request.message[:50]}... (position={position})")
|
||||||
|
return position
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise QueueFullError(f"TTS queue is full (max_size={self.max_size})")
|
||||||
|
|
||||||
|
async def _process_queue(self) -> None:
|
||||||
|
"""Background task that processes queued requests."""
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
# Wait for a request (with timeout to allow checking _running)
|
||||||
|
try:
|
||||||
|
request = await asyncio.wait_for(
|
||||||
|
self._queue.get(),
|
||||||
|
timeout=1.0,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
await self._process_request(request)
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in queue processor: {e}")
|
||||||
|
self._stats.errors += 1
|
||||||
|
|
||||||
|
async def _process_request(self, request: TTSRequest) -> None:
|
||||||
|
"""
|
||||||
|
Process a single TTS request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The TTS request to process
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug(f"Processing TTS request: {request.message[:50]}...")
|
||||||
|
|
||||||
|
if not request.voice_enabled:
|
||||||
|
logger.debug("Voice disabled, skipping TTS")
|
||||||
|
self._stats.processed += 1
|
||||||
|
return
|
||||||
|
|
||||||
|
# Synthesize audio (run in thread pool to avoid blocking)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
audio_data = await asyncio.wait_for(
|
||||||
|
loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
self.tts_engine.synthesize_to_float32,
|
||||||
|
request.message,
|
||||||
|
request.voice,
|
||||||
|
),
|
||||||
|
timeout=self.request_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(audio_data) == 0:
|
||||||
|
logger.warning("TTS generated empty audio")
|
||||||
|
self._stats.processed += 1
|
||||||
|
return
|
||||||
|
|
||||||
|
# Play audio
|
||||||
|
self.audio_player.play(audio_data, self.tts_engine.get_sample_rate())
|
||||||
|
|
||||||
|
# Wait for playback to complete
|
||||||
|
await self.audio_player.wait_async()
|
||||||
|
|
||||||
|
# Update stats
|
||||||
|
duration = len(audio_data) / self.tts_engine.get_sample_rate()
|
||||||
|
self._stats.processed += 1
|
||||||
|
self._stats.total_audio_seconds += duration
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
logger.debug(f"Request processed: {duration:.2f}s audio in {elapsed:.2f}s")
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"Request timed out after {self.request_timeout}s")
|
||||||
|
self._stats.errors += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing request: {e}")
|
||||||
|
self._stats.errors += 1
|
||||||
|
|
||||||
|
@property
def size(self) -> int:
    """Number of requests currently held in the queue."""
    pending = self._queue.qsize()
    return pending
|
||||||
|
|
||||||
|
@property
def capacity(self) -> int:
    """Configured queue limit (the value of ``max_size``)."""
    limit = self.max_size
    return limit
|
||||||
|
|
||||||
|
@property
def utilization(self) -> float:
    """Queue fill level as a percentage of ``max_size`` (0.0 when it is zero)."""
    limit = self.max_size
    # A zero limit would divide by zero; report such a queue as 0% used.
    return (self.size / limit) * 100.0 if limit != 0 else 0.0
|
||||||
|
|
||||||
|
@property
def stats(self) -> QueueStats:
    """The queue's statistics object (returned as-is, not copied)."""
    counters = self._stats
    return counters
|
||||||
|
|
||||||
|
def get_status(self) -> dict:
    """Build a status snapshot of the queue for health checks.

    Returns:
        Dict with the current size, capacity, utilization (rounded to one
        decimal), the processed/error counters, total audio seconds
        (rounded to one decimal) and the running flag.
    """
    counters = self._stats
    snapshot = dict(
        size=self.size,
        capacity=self.capacity,
        utilization=round(self.utilization, 1),
    )
    snapshot["processed"] = counters.processed
    snapshot["errors"] = counters.errors
    snapshot["total_audio_seconds"] = round(counters.total_audio_seconds, 1)
    snapshot["running"] = self._running
    return snapshot
|
||||||
198
app/routes.py
Normal file
198
app/routes.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
"""
|
||||||
|
API routes for voice-server.
|
||||||
|
|
||||||
|
Defines all HTTP endpoints:
|
||||||
|
- POST /notify: Submit text for TTS playback
|
||||||
|
- GET /health: Health check endpoint
|
||||||
|
- GET /voices: List available voice models
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Response, status
|
||||||
|
|
||||||
|
from app.config import get_settings
|
||||||
|
from app.models import (
|
||||||
|
ErrorResponse,
|
||||||
|
HealthResponse,
|
||||||
|
NotifyRequest,
|
||||||
|
NotifyResponse,
|
||||||
|
QueueStatus,
|
||||||
|
VoiceInfo,
|
||||||
|
VoicesResponse,
|
||||||
|
)
|
||||||
|
from app.queue_manager import QueueFullError, TTSRequest
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_components():
    """Look up the shared runtime objects on ``app.main`` at call time.

    Returns:
        Tuple of (tts_engine, audio_player, queue_manager, get_uptime_seconds).
    """
    # Imported inside the function rather than at module import time
    # (presumably to avoid a circular import with app.main - confirm).
    from app import main as app_main

    components = (
        app_main.tts_engine,
        app_main.audio_player,
        app_main.queue_manager,
        app_main.get_uptime_seconds,
    )
    return components
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/notify",
    response_model=NotifyResponse,
    status_code=status.HTTP_202_ACCEPTED,
    responses={
        422: {"model": ErrorResponse, "description": "Validation error"},
        503: {"model": ErrorResponse, "description": "Queue full"},
    },
)
async def notify(request: NotifyRequest) -> NotifyResponse:
    """
    Submit text for TTS playback.

    Accepts a text message and queues it for text-to-speech conversion
    and playback through the system speakers.

    Returns immediately with queue position; audio plays asynchronously.

    Raises:
        HTTPException: 422 if the requested (or default) voice is not
            installed, 503 if the queue rejects the request as full.
    """
    # The audio player and uptime getter are not needed by this endpoint.
    tts_engine, _, queue_manager, _ = _get_components()
    settings = get_settings()

    # Validate voice exists if specified (skipped when no engine is set up)
    voice = request.voice or settings.default_voice
    if tts_engine and not tts_engine.is_voice_available(voice):
        available = [v["name"] for v in tts_engine.list_voices()]
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Voice '{voice}' not found. Available: {available}",
        )

    # Create TTS request; voice is played only if both the caller and the
    # server-wide setting enable it.
    tts_request = TTSRequest(
        message=request.message,
        voice=voice,
        rate=request.rate,
        voice_enabled=request.voice_enabled and settings.voice_enabled,
    )

    # Enqueue request
    try:
        if queue_manager:
            position = await queue_manager.enqueue(tts_request)
        else:
            position = 1  # Fallback for testing
    except QueueFullError as e:
        # Chain the cause so the original queue error stays in the traceback.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=str(e),
        ) from e

    return NotifyResponse(
        status="queued",
        message_length=len(request.message),
        queue_position=position,
        voice_model=voice,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/health",
    response_model=HealthResponse,
    responses={
        503: {"model": HealthResponse, "description": "Service unhealthy"},
    },
)
async def health(response: Response) -> HealthResponse:
    """
    Report the health of the service.

    Collects the TTS engine and audio player health checks plus the queue
    status. Any TTS or audio status other than "healthy" is recorded as an
    error; if at least one error was recorded the overall status is
    "unhealthy" and the HTTP status code is set to 503.
    """
    tts_engine, audio_player, queue_manager, get_uptime_seconds = _get_components()
    settings = get_settings()

    problems = []

    # TTS engine: only consulted when one has been set up.
    if tts_engine:
        tts_report = tts_engine.health_check()
        if tts_report.get("status", "unknown") != "healthy":
            problems.append(f"TTS: {tts_report.get('error', 'Unknown error')}")

    # Audio output: stays "unknown" when no player has been set up.
    audio_state = "unknown"
    if audio_player:
        audio_report = audio_player.health_check()
        audio_state = audio_report.get("status", "unknown")
        if audio_state != "healthy":
            problems.append(f"Audio: {audio_report.get('error', 'Unknown error')}")

    # Queue: fall back to an empty queue built from settings when there is
    # no queue manager; the request counters are then unknown (None).
    if not queue_manager:
        queue_summary = QueueStatus(
            size=0,
            capacity=settings.queue_max_size,
            utilization=0.0,
        )
        processed_count = None
        failed_count = None
    else:
        snapshot = queue_manager.get_status()
        queue_summary = QueueStatus(
            size=snapshot["size"],
            capacity=snapshot["capacity"],
            utilization=snapshot["utilization"],
        )
        processed_count = snapshot["processed"]
        failed_count = snapshot["errors"]

    is_healthy = not problems
    if not is_healthy:
        # Signal the unhealthy state through the HTTP status as well.
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE

    return HealthResponse(
        status="healthy" if is_healthy else "unhealthy",
        uptime_seconds=get_uptime_seconds(),
        queue=queue_summary,
        tts_engine="piper",
        audio_output="available" if audio_state == "healthy" else "unavailable",
        total_requests=processed_count,
        failed_requests=failed_count,
        errors=problems or None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/voices",
    response_model=VoicesResponse,
)
async def list_voices() -> VoicesResponse:
    """
    Return the voice models reported by the TTS engine.

    The list is empty when no TTS engine is set up. The response also
    carries the default voice name from the settings.
    """
    tts_engine, _, _, _ = _get_components()
    settings = get_settings()

    reported = tts_engine.list_voices() if tts_engine else []
    voice_models = [
        VoiceInfo(
            name=entry["name"],
            language=entry["language"],
            quality=entry["quality"],
            size_mb=entry["size_mb"],
            installed=entry["installed"],
        )
        for entry in reported
    ]

    return VoicesResponse(
        voices=voice_models,
        default_voice=settings.default_voice,
    )
|
||||||
287
app/tts_engine.py
Normal file
287
app/tts_engine.py
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
"""
|
||||||
|
TTS Engine module for voice-server.
|
||||||
|
|
||||||
|
Provides text-to-speech synthesis using Piper TTS.
|
||||||
|
Supports multiple voice models with lazy loading and caching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TTSEngine(Protocol):
|
||||||
|
"""Protocol defining the TTS engine interface."""
|
||||||
|
|
||||||
|
def synthesize(self, text: str, voice: str | None = None) -> np.ndarray:
|
||||||
|
"""Convert text to audio samples."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_sample_rate(self) -> int:
|
||||||
|
"""Get the audio sample rate."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def list_voices(self) -> list[dict]:
|
||||||
|
"""List available voice models."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class PiperTTSEngine:
|
||||||
|
"""
|
||||||
|
Piper TTS engine implementation.
|
||||||
|
|
||||||
|
Provides high-quality neural text-to-speech using Piper's ONNX models.
|
||||||
|
Voice models are loaded lazily and cached for performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_dir: str = "./models", default_voice: str = "en_US-lessac-medium"):
|
||||||
|
"""
|
||||||
|
Initialize the Piper TTS engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_dir: Directory containing voice model files (.onnx + .onnx.json)
|
||||||
|
default_voice: Default voice model name to use
|
||||||
|
"""
|
||||||
|
self.model_dir = Path(model_dir)
|
||||||
|
self.default_voice = default_voice
|
||||||
|
self._voices: dict = {} # Cache of loaded PiperVoice instances
|
||||||
|
self._voice_metadata: dict = {} # Cache of voice metadata
|
||||||
|
self._sample_rate: int = 22050 # Piper default sample rate
|
||||||
|
|
||||||
|
# Ensure model directory exists
|
||||||
|
self.model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
logger.info(f"PiperTTSEngine initialized with model_dir={model_dir}")
|
||||||
|
|
||||||
|
def _get_voice_path(self, voice_name: str) -> tuple[Path, Path]:
|
||||||
|
"""
|
||||||
|
Get paths to voice model files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
voice_name: Name of the voice model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (onnx_path, json_path)
|
||||||
|
"""
|
||||||
|
onnx_path = self.model_dir / f"{voice_name}.onnx"
|
||||||
|
json_path = self.model_dir / f"{voice_name}.onnx.json"
|
||||||
|
return onnx_path, json_path
|
||||||
|
|
||||||
|
def _load_voice(self, voice_name: str):
|
||||||
|
"""
|
||||||
|
Load a voice model (lazy loading with caching).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
voice_name: Name of the voice model to load
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded PiperVoice instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If voice model files don't exist
|
||||||
|
RuntimeError: If voice model fails to load
|
||||||
|
"""
|
||||||
|
if voice_name in self._voices:
|
||||||
|
return self._voices[voice_name]
|
||||||
|
|
||||||
|
onnx_path, json_path = self._get_voice_path(voice_name)
|
||||||
|
|
||||||
|
if not onnx_path.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Voice model not found: {voice_name}. "
|
||||||
|
f"Expected file: {onnx_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from piper import PiperVoice
|
||||||
|
|
||||||
|
logger.info(f"Loading voice model: {voice_name}")
|
||||||
|
voice = PiperVoice.load(str(onnx_path), config_path=str(json_path) if json_path.exists() else None)
|
||||||
|
self._voices[voice_name] = voice
|
||||||
|
|
||||||
|
# Update sample rate from loaded voice
|
||||||
|
if hasattr(voice, 'config') and voice.config:
|
||||||
|
self._sample_rate = voice.config.sample_rate
|
||||||
|
|
||||||
|
logger.info(f"Voice model loaded: {voice_name} (sample_rate={self._sample_rate})")
|
||||||
|
return voice
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load voice model {voice_name}: {e}")
|
||||||
|
raise RuntimeError(f"Failed to load voice model: {e}") from e
|
||||||
|
|
||||||
|
def synthesize(self, text: str, voice: str | None = None) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert text to audio samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to convert to speech
|
||||||
|
voice: Voice model name (uses default if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NumPy array of audio samples (int16)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If voice model not found
|
||||||
|
RuntimeError: If synthesis fails
|
||||||
|
"""
|
||||||
|
voice_name = voice or self.default_voice
|
||||||
|
|
||||||
|
if not text or not text.strip():
|
||||||
|
# Return empty audio for empty text
|
||||||
|
return np.array([], dtype=np.int16)
|
||||||
|
|
||||||
|
try:
|
||||||
|
piper_voice = self._load_voice(voice_name)
|
||||||
|
|
||||||
|
# Synthesize audio - piper returns an iterator of AudioChunk objects
|
||||||
|
audio_chunks = []
|
||||||
|
for chunk in piper_voice.synthesize(text):
|
||||||
|
# Each chunk has audio_int16_array property
|
||||||
|
audio_chunks.append(chunk.audio_int16_array)
|
||||||
|
|
||||||
|
if not audio_chunks:
|
||||||
|
return np.array([], dtype=np.int16)
|
||||||
|
|
||||||
|
# Concatenate all chunks
|
||||||
|
audio_array = np.concatenate(audio_chunks)
|
||||||
|
|
||||||
|
logger.debug(f"Synthesized {len(text)} chars -> {len(audio_array)} samples")
|
||||||
|
return audio_array
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"TTS synthesis failed: {e}")
|
||||||
|
raise RuntimeError(f"TTS synthesis failed: {e}") from e
|
||||||
|
|
||||||
|
def synthesize_to_float32(self, text: str, voice: str | None = None) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert text to float32 audio samples (normalized -1.0 to 1.0).
|
||||||
|
|
||||||
|
This format is preferred by sounddevice for playback.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to convert to speech
|
||||||
|
voice: Voice model name (uses default if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NumPy array of float32 audio samples
|
||||||
|
"""
|
||||||
|
int16_audio = self.synthesize(text, voice)
|
||||||
|
|
||||||
|
if len(int16_audio) == 0:
|
||||||
|
return np.array([], dtype=np.float32)
|
||||||
|
|
||||||
|
# Convert int16 to float32 normalized
|
||||||
|
float32_audio = int16_audio.astype(np.float32) / 32768.0
|
||||||
|
return float32_audio
|
||||||
|
|
||||||
|
def get_sample_rate(self) -> int:
|
||||||
|
"""Get the audio sample rate for the current voice."""
|
||||||
|
return self._sample_rate
|
||||||
|
|
||||||
|
def list_voices(self) -> list[dict]:
|
||||||
|
"""
|
||||||
|
List available voice models in the model directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of voice info dictionaries with name, language, quality, etc.
|
||||||
|
"""
|
||||||
|
voices = []
|
||||||
|
|
||||||
|
if not self.model_dir.exists():
|
||||||
|
return voices
|
||||||
|
|
||||||
|
# Find all .onnx files
|
||||||
|
for onnx_file in self.model_dir.glob("*.onnx"):
|
||||||
|
voice_name = onnx_file.stem
|
||||||
|
json_file = onnx_file.with_suffix(".onnx.json")
|
||||||
|
|
||||||
|
voice_info = {
|
||||||
|
"name": voice_name,
|
||||||
|
"language": self._extract_language(voice_name),
|
||||||
|
"quality": self._extract_quality(voice_name),
|
||||||
|
"size_mb": round(onnx_file.stat().st_size / (1024 * 1024), 1),
|
||||||
|
"installed": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Try to load additional metadata from JSON config
|
||||||
|
if json_file.exists():
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
with open(json_file) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
if "language" in config:
|
||||||
|
voice_info["language"] = config["language"].get("code", voice_info["language"])
|
||||||
|
except Exception:
|
||||||
|
pass # Use extracted values if JSON parsing fails
|
||||||
|
|
||||||
|
voices.append(voice_info)
|
||||||
|
|
||||||
|
return sorted(voices, key=lambda v: v["name"])
|
||||||
|
|
||||||
|
def _extract_language(self, voice_name: str) -> str:
|
||||||
|
"""Extract language code from voice name (e.g., 'en_US' from 'en_US-lessac-medium')."""
|
||||||
|
parts = voice_name.split("-")
|
||||||
|
if parts:
|
||||||
|
return parts[0]
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
def _extract_quality(self, voice_name: str) -> str:
|
||||||
|
"""Extract quality level from voice name (e.g., 'medium' from 'en_US-lessac-medium')."""
|
||||||
|
parts = voice_name.split("-")
|
||||||
|
if len(parts) >= 3:
|
||||||
|
quality = parts[-1].lower()
|
||||||
|
if quality in ("low", "medium", "high", "x_low", "x_high"):
|
||||||
|
return quality
|
||||||
|
return "medium"
|
||||||
|
|
||||||
|
def is_voice_available(self, voice_name: str) -> bool:
|
||||||
|
"""Check if a voice model is installed."""
|
||||||
|
onnx_path, _ = self._get_voice_path(voice_name)
|
||||||
|
return onnx_path.exists()
|
||||||
|
|
||||||
|
def health_check(self) -> dict:
|
||||||
|
"""
|
||||||
|
Perform a health check on the TTS engine.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with status and any error messages
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check if piper is importable
|
||||||
|
from piper import PiperVoice # noqa: F401
|
||||||
|
|
||||||
|
# Check if model directory exists
|
||||||
|
if not self.model_dir.exists():
|
||||||
|
return {
|
||||||
|
"status": "degraded",
|
||||||
|
"error": f"Model directory does not exist: {self.model_dir}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if default voice is available
|
||||||
|
if not self.is_voice_available(self.default_voice):
|
||||||
|
available = [v["name"] for v in self.list_voices()]
|
||||||
|
return {
|
||||||
|
"status": "degraded",
|
||||||
|
"error": f"Default voice not found: {self.default_voice}",
|
||||||
|
"available_voices": available,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"status": "healthy"}
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
return {
|
||||||
|
"status": "unhealthy",
|
||||||
|
"error": f"Piper TTS not installed: {e}",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "unhealthy",
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
69
pyproject.toml
Normal file
69
pyproject.toml
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
[project]
|
||||||
|
name = "voice-server"
|
||||||
|
version = "1.0.0"
|
||||||
|
description = "Local HTTP service for text-to-speech playback"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
license = {text = "MIT"}
|
||||||
|
authors = [
|
||||||
|
{name = "Cal Corum", email = "cal.corum@gmail.com"}
|
||||||
|
]
|
||||||
|
keywords = ["tts", "text-to-speech", "piper", "fastapi", "voice"]
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.115.0",
|
||||||
|
"uvicorn[standard]>=0.32.0",
|
||||||
|
"piper-tts>=1.2.0",
|
||||||
|
"sounddevice>=0.5.0",
|
||||||
|
"numpy>=1.26.0",
|
||||||
|
"pydantic>=2.10.0",
|
||||||
|
"pydantic-settings>=2.6.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
|
"psutil>=6.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.3.0",
|
||||||
|
"pytest-asyncio>=0.24.0",
|
||||||
|
"pytest-cov>=6.0.0",
|
||||||
|
"httpx>=0.28.0",
|
||||||
|
"ruff>=0.8.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
voice-server = "app.main:run"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["app"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
python_files = ["test_*.py"]
|
||||||
|
python_functions = ["test_*"]
|
||||||
|
addopts = "-v --tb=short"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
target-version = "py310"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["E", "F", "I", "N", "W", "UP"]
|
||||||
|
ignore = ["E501"]
|
||||||
|
|
||||||
|
[tool.coverage.run]
|
||||||
|
source = ["app"]
|
||||||
|
omit = ["tests/*"]
|
||||||
|
|
||||||
|
[tool.coverage.report]
|
||||||
|
exclude_lines = [
|
||||||
|
"pragma: no cover",
|
||||||
|
"if TYPE_CHECKING:",
|
||||||
|
"raise NotImplementedError",
|
||||||
|
]
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
324
tests/test_api.py
Normal file
324
tests/test_api.py
Normal file
@ -0,0 +1,324 @@
|
|||||||
|
"""
|
||||||
|
TDD Tests for API endpoints.
|
||||||
|
|
||||||
|
These tests verify the API contract for all voice-server endpoints:
|
||||||
|
- POST /notify: TTS request submission
|
||||||
|
- GET /health: Health check
|
||||||
|
- GET /voices: Voice model listing
|
||||||
|
|
||||||
|
Uses httpx.AsyncClient with ASGITransport to call the FastAPI app in-process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def client():
    """Yield an httpx AsyncClient wired directly to the FastAPI app (no network)."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http_client:
        yield http_client
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotifyEndpoint:
    """Contract tests for POST /notify."""

    async def test_valid_request_returns_202(self, client: AsyncClient):
        """A request carrying only a message is accepted with 202."""
        payload = {"message": "Hello, world!"}
        resp = await client.post("/notify", json=payload)

        assert resp.status_code == 202

    async def test_valid_request_returns_queued_status(self, client: AsyncClient):
        """A successful submission reports status 'queued'."""
        resp = await client.post("/notify", json={"message": "Test message"})

        body = resp.json()
        assert body["status"] == "queued"

    async def test_response_includes_message_length(self, client: AsyncClient):
        """The response echoes the length of the submitted message."""
        message = "This is a test message"
        resp = await client.post("/notify", json={"message": message})

        body = resp.json()
        assert body["message_length"] == len(message)

    async def test_response_includes_queue_position(self, client: AsyncClient):
        """The response carries a positive integer queue position."""
        resp = await client.post("/notify", json={"message": "Test"})

        body = resp.json()
        assert "queue_position" in body
        position = body["queue_position"]
        assert isinstance(position, int)
        assert position >= 1

    async def test_response_includes_voice_model(self, client: AsyncClient):
        """Without an explicit voice the default voice model is reported."""
        resp = await client.post("/notify", json={"message": "Test"})

        body = resp.json()
        assert "voice_model" in body
        assert body["voice_model"] == "en_US-lessac-medium"  # default

    async def test_custom_voice_is_preserved(self, client: AsyncClient):
        """An explicitly requested voice is reflected in the response."""
        payload = {"message": "Test", "voice": "en_US-libritts-high"}
        resp = await client.post("/notify", json=payload)

        body = resp.json()
        assert body["voice_model"] == "en_US-libritts-high"

    async def test_missing_message_returns_422(self, client: AsyncClient):
        """A body without 'message' is rejected with 422."""
        resp = await client.post("/notify", json={})

        assert resp.status_code == 422

    async def test_empty_message_returns_422(self, client: AsyncClient):
        """An empty message string is rejected with 422."""
        resp = await client.post("/notify", json={"message": ""})

        assert resp.status_code == 422

    async def test_message_too_long_returns_422(self, client: AsyncClient):
        """A message longer than 10000 characters is rejected with 422."""
        oversized = "a" * 10001
        resp = await client.post("/notify", json={"message": oversized})

        assert resp.status_code == 422

    async def test_invalid_rate_returns_422(self, client: AsyncClient):
        """A rate outside the valid range is rejected with 422."""
        payload = {"message": "Test", "rate": 500}
        resp = await client.post("/notify", json=payload)

        assert resp.status_code == 422

    async def test_invalid_voice_pattern_returns_422(self, client: AsyncClient):
        """A voice name with invalid characters is rejected with 422."""
        payload = {"message": "Test", "voice": "invalid/voice"}
        resp = await client.post("/notify", json=payload)

        assert resp.status_code == 422

    async def test_malformed_json_returns_422(self, client: AsyncClient):
        """A body that is not valid JSON is rejected with 422."""
        json_headers = {"Content-Type": "application/json"}
        resp = await client.post(
            "/notify",
            content="not valid json",
            headers=json_headers,
        )

        assert resp.status_code == 422

    async def test_whitespace_message_is_stripped(self, client: AsyncClient):
        """Surrounding whitespace is removed before the length is reported."""
        resp = await client.post("/notify", json={"message": "  Hello  "})

        assert resp.status_code == 202
        body = resp.json()
        assert body["message_length"] == 5  # "Hello" without whitespace
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthEndpoint:
    """Contract tests for GET /health."""

    async def test_health_returns_200(self, client: AsyncClient):
        """A healthy service answers with 200."""
        resp = await client.get("/health")

        assert resp.status_code == 200

    async def test_health_returns_status(self, client: AsyncClient):
        """The body has a status field with one of the two allowed values."""
        body = (await client.get("/health")).json()

        assert "status" in body
        assert body["status"] in ("healthy", "unhealthy")

    async def test_health_returns_uptime(self, client: AsyncClient):
        """The body reports uptime as a non-negative integer of seconds."""
        body = (await client.get("/health")).json()

        assert "uptime_seconds" in body
        uptime = body["uptime_seconds"]
        assert isinstance(uptime, int)
        assert uptime >= 0

    async def test_health_returns_queue_status(self, client: AsyncClient):
        """The body has a queue object with size, capacity and utilization."""
        body = (await client.get("/health")).json()

        assert "queue" in body
        queue_info = body["queue"]
        assert "size" in queue_info
        assert "capacity" in queue_info
        assert "utilization" in queue_info

    async def test_health_returns_tts_engine(self, client: AsyncClient):
        """The body names the TTS engine, which is 'piper'."""
        body = (await client.get("/health")).json()

        assert "tts_engine" in body
        assert body["tts_engine"] == "piper"

    async def test_health_returns_audio_output(self, client: AsyncClient):
        """The body includes the audio output status."""
        body = (await client.get("/health")).json()

        assert "audio_output" in body
|
||||||
|
|
||||||
|
|
||||||
|
class TestVoicesEndpoint:
    """Contract tests for GET /voices."""

    async def test_voices_returns_200(self, client: AsyncClient):
        """The endpoint answers with 200."""
        resp = await client.get("/voices")

        assert resp.status_code == 200

    async def test_voices_returns_list(self, client: AsyncClient):
        """The body has a 'voices' entry that is a list."""
        body = (await client.get("/voices")).json()

        assert "voices" in body
        assert isinstance(body["voices"], list)

    async def test_voices_returns_default_voice(self, client: AsyncClient):
        """The body reports the default voice name."""
        body = (await client.get("/voices")).json()

        assert "default_voice" in body
        assert body["default_voice"] == "en_US-lessac-medium"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAPIDocumentation:
    """Checks that the generated API documentation is served."""

    async def test_openapi_json_available(self, client: AsyncClient):
        """/openapi.json answers 200 with 'openapi' and 'paths' entries."""
        resp = await client.get("/openapi.json")

        assert resp.status_code == 200
        schema = resp.json()
        assert "openapi" in schema
        assert "paths" in schema

    async def test_docs_endpoint_available(self, client: AsyncClient):
        """/docs (Swagger UI) answers 200 with an HTML content type."""
        resp = await client.get("/docs")

        assert resp.status_code == 200
        content_type = resp.headers.get("content-type", "")
        assert "text/html" in content_type
|
||||||
|
|
||||||
|
|
||||||
|
class TestCORS:
    """Checks for the CORS middleware."""

    async def test_cors_headers_present(self, client: AsyncClient):
        """A CORS preflight for POST /notify succeeds and names an allowed origin."""
        preflight_headers = {
            "Origin": "http://localhost:3000",
            "Access-Control-Request-Method": "POST",
        }
        resp = await client.options("/notify", headers=preflight_headers)

        # FastAPI returns 200 for OPTIONS with CORS
        assert resp.status_code == 200
        assert "access-control-allow-origin" in resp.headers
|
||||||
300
tests/test_config.py
Normal file
300
tests/test_config.py
Normal file
@ -0,0 +1,300 @@
|
|||||||
|
"""
|
||||||
|
TDD Tests for configuration loading.
|
||||||
|
|
||||||
|
These tests define the expected behavior for the Settings class which loads
|
||||||
|
configuration from environment variables with sensible defaults.
|
||||||
|
|
||||||
|
Test Coverage:
|
||||||
|
- Default values when no environment variables are set
|
||||||
|
- Environment variable overrides
|
||||||
|
- Validation of configuration values
|
||||||
|
- Path handling for model directory
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestSettingsDefaults:
    """Checks the value each Settings field takes when its env var is unset."""

    def test_default_host(self, monkeypatch):
        """With HOST unset, ``host`` is 0.0.0.0 (listen on all interfaces)."""
        # Drop any value inherited from the surrounding environment.
        monkeypatch.delenv("HOST", raising=False)

        from app.config import Settings

        assert Settings().host == "0.0.0.0"

    def test_default_port(self, monkeypatch):
        """With PORT unset, ``port`` is 8888."""
        monkeypatch.delenv("PORT", raising=False)

        from app.config import Settings

        assert Settings().port == 8888

    def test_default_model_dir(self, monkeypatch):
        """With MODEL_DIR unset, ``model_dir`` is ./models."""
        monkeypatch.delenv("MODEL_DIR", raising=False)

        from app.config import Settings

        assert Settings().model_dir == "./models"

    def test_default_voice(self, monkeypatch):
        """With DEFAULT_VOICE unset, ``default_voice`` is en_US-lessac-medium."""
        monkeypatch.delenv("DEFAULT_VOICE", raising=False)

        from app.config import Settings

        assert Settings().default_voice == "en_US-lessac-medium"

    def test_default_rate(self, monkeypatch):
        """With DEFAULT_RATE unset, ``default_rate`` is 170 (words per minute)."""
        monkeypatch.delenv("DEFAULT_RATE", raising=False)

        from app.config import Settings

        assert Settings().default_rate == 170

    def test_default_queue_max_size(self, monkeypatch):
        """With QUEUE_MAX_SIZE unset, ``queue_max_size`` is 50."""
        monkeypatch.delenv("QUEUE_MAX_SIZE", raising=False)

        from app.config import Settings

        assert Settings().queue_max_size == 50

    def test_default_request_timeout(self, monkeypatch):
        """With REQUEST_TIMEOUT_SECONDS unset, ``request_timeout_seconds`` is 60."""
        monkeypatch.delenv("REQUEST_TIMEOUT_SECONDS", raising=False)

        from app.config import Settings

        assert Settings().request_timeout_seconds == 60

    def test_default_log_level(self, monkeypatch):
        """With LOG_LEVEL unset, ``log_level`` is INFO."""
        monkeypatch.delenv("LOG_LEVEL", raising=False)

        from app.config import Settings

        assert Settings().log_level == "INFO"

    def test_default_voice_enabled(self, monkeypatch):
        """With VOICE_ENABLED unset, ``voice_enabled`` is exactly True."""
        monkeypatch.delenv("VOICE_ENABLED", raising=False)

        from app.config import Settings

        assert Settings().voice_enabled is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestSettingsEnvOverrides:
    """Checks that each environment variable replaces its Settings default."""

    def test_host_override(self, monkeypatch):
        """HOST=127.0.0.1 is reflected in ``host``."""
        monkeypatch.setenv("HOST", "127.0.0.1")

        from app.config import Settings

        assert Settings().host == "127.0.0.1"

    def test_port_override(self, monkeypatch):
        """PORT=9000 is reflected in ``port`` as the integer 9000."""
        monkeypatch.setenv("PORT", "9000")

        from app.config import Settings

        assert Settings().port == 9000

    def test_model_dir_override(self, monkeypatch):
        """MODEL_DIR=/opt/voice-models is reflected in ``model_dir``."""
        monkeypatch.setenv("MODEL_DIR", "/opt/voice-models")

        from app.config import Settings

        assert Settings().model_dir == "/opt/voice-models"

    def test_default_voice_override(self, monkeypatch):
        """DEFAULT_VOICE=en_US-libritts-high is reflected in ``default_voice``."""
        monkeypatch.setenv("DEFAULT_VOICE", "en_US-libritts-high")

        from app.config import Settings

        assert Settings().default_voice == "en_US-libritts-high"

    def test_queue_max_size_override(self, monkeypatch):
        """QUEUE_MAX_SIZE=100 is reflected in ``queue_max_size`` as the integer 100."""
        monkeypatch.setenv("QUEUE_MAX_SIZE", "100")

        from app.config import Settings

        assert Settings().queue_max_size == 100

    def test_log_level_override(self, monkeypatch):
        """LOG_LEVEL=DEBUG is reflected in ``log_level``."""
        monkeypatch.setenv("LOG_LEVEL", "DEBUG")

        from app.config import Settings

        assert Settings().log_level == "DEBUG"

    def test_voice_enabled_false(self, monkeypatch):
        """VOICE_ENABLED=false turns ``voice_enabled`` into exactly False."""
        monkeypatch.setenv("VOICE_ENABLED", "false")

        from app.config import Settings

        assert Settings().voice_enabled is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestSettingsValidation:
    """Checks that out-of-range or unknown configuration values are rejected."""

    def test_port_must_be_positive(self, monkeypatch):
        """PORT=-1 makes Settings() raise ValidationError."""
        monkeypatch.setenv("PORT", "-1")

        from pydantic import ValidationError
        from app.config import Settings

        with pytest.raises(ValidationError):
            Settings()

    def test_port_must_be_valid_range(self, monkeypatch):
        """PORT=70000 (above the 1-65535 range) makes Settings() raise ValidationError."""
        monkeypatch.setenv("PORT", "70000")

        from pydantic import ValidationError
        from app.config import Settings

        with pytest.raises(ValidationError):
            Settings()

    def test_queue_max_size_must_be_positive(self, monkeypatch):
        """QUEUE_MAX_SIZE=0 makes Settings() raise ValidationError."""
        monkeypatch.setenv("QUEUE_MAX_SIZE", "0")

        from pydantic import ValidationError
        from app.config import Settings

        with pytest.raises(ValidationError):
            Settings()

    def test_request_timeout_must_be_positive(self, monkeypatch):
        """REQUEST_TIMEOUT_SECONDS=0 makes Settings() raise ValidationError."""
        monkeypatch.setenv("REQUEST_TIMEOUT_SECONDS", "0")

        from pydantic import ValidationError
        from app.config import Settings

        with pytest.raises(ValidationError):
            Settings()

    def test_default_rate_must_be_in_range(self, monkeypatch):
        """DEFAULT_RATE=500 (above the 50-400 range) makes Settings() raise ValidationError."""
        monkeypatch.setenv("DEFAULT_RATE", "500")

        from pydantic import ValidationError
        from app.config import Settings

        with pytest.raises(ValidationError):
            Settings()

    def test_log_level_must_be_valid(self, monkeypatch):
        """LOG_LEVEL=INVALID makes Settings() raise ValidationError."""
        monkeypatch.setenv("LOG_LEVEL", "INVALID")

        from pydantic import ValidationError
        from app.config import Settings

        with pytest.raises(ValidationError):
            Settings()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSettings:
    """Tests for the get_settings function."""

    def test_get_settings_returns_settings_instance(self, monkeypatch):
        """
        get_settings should return a Settings instance.

        Any previously cached instance is dropped first, so the object under
        test is built during this test rather than inherited from an earlier one.
        """
        from app.config import get_settings, Settings

        # Clear the cache so a fresh Settings is built. The lookup is guarded
        # because get_settings is defined outside this file: functools cache
        # wrappers expose cache_clear(), a plain function does not.
        clear_cache = getattr(get_settings, "cache_clear", None)
        if callable(clear_cache):
            clear_cache()

        settings = get_settings()
        assert isinstance(settings, Settings)

    def test_get_settings_is_cached(self, monkeypatch):
        """
        get_settings should return the same cached instance.
        """
        from app.config import get_settings

        settings1 = get_settings()
        settings2 = get_settings()
        # Identity, not equality: both calls must hand back the very same object.
        assert settings1 is settings2
|
||||||
388
tests/test_models.py
Normal file
388
tests/test_models.py
Normal file
@ -0,0 +1,388 @@
|
|||||||
|
"""
|
||||||
|
TDD Tests for Pydantic request/response models.
|
||||||
|
|
||||||
|
These tests define the API contract for the voice server's request and response models.
|
||||||
|
Tests are written BEFORE implementation to drive the design.
|
||||||
|
|
||||||
|
Test Coverage:
|
||||||
|
- NotifyRequest: Validates incoming TTS requests with message, voice, rate, voice_enabled
|
||||||
|
- NotifyResponse: Validates successful queue responses
|
||||||
|
- HealthResponse: Validates health check responses
|
||||||
|
- ErrorResponse: Validates error response format
|
||||||
|
- VoiceInfo/VoicesResponse: Validates voice listing responses
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotifyRequest:
    """Validation contract of NotifyRequest, the body of an incoming TTS request."""

    def test_valid_request_with_message_only(self):
        """Only ``message`` is required; the other fields fall back to defaults."""
        from app.models import NotifyRequest

        req = NotifyRequest(message="Hello, world!")

        assert req.message == "Hello, world!"
        assert req.voice == "en_US-lessac-medium"  # default voice
        assert req.rate == 170  # default rate
        assert req.voice_enabled is True  # enabled by default

    def test_valid_request_with_all_fields(self):
        """Explicitly supplied values are kept as given."""
        from app.models import NotifyRequest

        req = NotifyRequest(
            message="Test message",
            voice="en_US-libritts-high",
            rate=200,
            voice_enabled=False,
        )

        assert req.message == "Test message"
        assert req.voice == "en_US-libritts-high"
        assert req.rate == 200
        assert req.voice_enabled is False

    def test_message_is_required(self):
        """Omitting ``message`` raises ValidationError with a ``missing`` error on it."""
        from app.models import NotifyRequest

        with pytest.raises(ValidationError) as exc_info:
            NotifyRequest()

        assert any(
            err["loc"] == ("message",) and err["type"] == "missing"
            for err in exc_info.value.errors()
        )

    def test_message_cannot_be_empty(self):
        """An empty ``message`` raises ValidationError located on ``message``."""
        from app.models import NotifyRequest

        with pytest.raises(ValidationError) as exc_info:
            NotifyRequest(message="")

        assert any("message" in str(err["loc"]) for err in exc_info.value.errors())

    def test_message_minimum_length_is_1(self):
        """A one-character ``message`` is accepted."""
        from app.models import NotifyRequest

        assert NotifyRequest(message="X").message == "X"

    def test_message_maximum_length_is_10000(self):
        """A ``message`` of exactly 10,000 characters is accepted."""
        from app.models import NotifyRequest

        req = NotifyRequest(message="a" * 10000)
        assert len(req.message) == 10000

    def test_message_over_10000_characters_rejected(self):
        """A ``message`` of 10,001 characters raises ValidationError on ``message``."""
        from app.models import NotifyRequest

        oversized = "a" * 10001
        with pytest.raises(ValidationError) as exc_info:
            NotifyRequest(message=oversized)

        assert any("message" in str(err["loc"]) for err in exc_info.value.errors())

    def test_message_whitespace_is_stripped(self):
        """Leading and trailing whitespace is removed from ``message``."""
        from app.models import NotifyRequest

        padded = " Hello, world! "
        assert NotifyRequest(message=padded).message == "Hello, world!"

    def test_rate_minimum_is_50(self):
        """``rate=49`` raises ValidationError located on ``rate``."""
        from app.models import NotifyRequest

        with pytest.raises(ValidationError) as exc_info:
            NotifyRequest(message="Test", rate=49)

        assert any("rate" in str(err["loc"]) for err in exc_info.value.errors())

    def test_rate_maximum_is_400(self):
        """``rate=401`` raises ValidationError located on ``rate``."""
        from app.models import NotifyRequest

        with pytest.raises(ValidationError) as exc_info:
            NotifyRequest(message="Test", rate=401)

        assert any("rate" in str(err["loc"]) for err in exc_info.value.errors())

    def test_rate_at_boundaries(self):
        """The exact boundary rates 50 and 400 are both accepted."""
        from app.models import NotifyRequest

        for boundary in (50, 400):
            assert NotifyRequest(message="Test", rate=boundary).rate == boundary

    def test_voice_pattern_validation(self):
        """Voice names made of letters, digits, underscores and hyphens are accepted."""
        from app.models import NotifyRequest

        for valid_name in ("en_US-lessac-medium", "voice_123"):
            assert NotifyRequest(message="Test", voice=valid_name).voice == valid_name

    def test_invalid_voice_pattern_rejected(self):
        """Voice names containing a slash or a space raise ValidationError."""
        from app.models import NotifyRequest

        for bad_name in ("invalid/voice", "invalid voice"):
            with pytest.raises(ValidationError):
                NotifyRequest(message="Test", voice=bad_name)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotifyResponse:
    """Shape of NotifyResponse, the reply sent once a request has been queued."""

    def test_successful_response_structure(self):
        """status, message_length, queue_position and voice_model are stored as given."""
        from app.models import NotifyResponse

        expected = {
            "status": "queued",
            "message_length": 42,
            "queue_position": 3,
            "voice_model": "en_US-lessac-medium",
        }
        resp = NotifyResponse(**expected)

        for field_name, value in expected.items():
            assert getattr(resp, field_name) == value

    def test_estimated_duration_is_optional(self):
        """Leaving out ``estimated_duration`` yields None."""
        from app.models import NotifyResponse

        resp = NotifyResponse(
            status="queued",
            message_length=42,
            queue_position=1,
            voice_model="en_US-lessac-medium",
        )

        assert resp.estimated_duration is None

    def test_estimated_duration_when_provided(self):
        """A supplied ``estimated_duration`` is stored unchanged."""
        from app.models import NotifyResponse

        resp = NotifyResponse(
            status="queued",
            message_length=42,
            queue_position=1,
            voice_model="en_US-lessac-medium",
            estimated_duration=2.5,
        )

        assert resp.estimated_duration == 2.5
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthResponse:
    """Shape of HealthResponse, the payload of the /health endpoint."""

    def test_healthy_response_structure(self):
        """A healthy response exposes status, uptime, queue figures, engine and audio state."""
        from app.models import HealthResponse, QueueStatus

        health = HealthResponse(
            status="healthy",
            uptime_seconds=3600,
            queue=QueueStatus(size=2, capacity=50, utilization=4.0),
            tts_engine="piper",
            audio_output="available",
        )

        assert health.status == "healthy"
        assert health.uptime_seconds == 3600
        # The nested queue model stays reachable through attribute access.
        assert health.queue.size == 2
        assert health.queue.capacity == 50
        assert health.tts_engine == "piper"
        assert health.audio_output == "available"

    def test_unhealthy_response_with_errors(self):
        """An unhealthy response carries the list of error messages it was given."""
        from app.models import HealthResponse, QueueStatus

        problems = ["Audio device not found", "TTS engine failed to initialize"]
        health = HealthResponse(
            status="unhealthy",
            uptime_seconds=100,
            queue=QueueStatus(size=0, capacity=50, utilization=0.0),
            tts_engine="piper",
            audio_output="unavailable",
            errors=problems,
        )

        assert health.status == "unhealthy"
        assert len(health.errors) == 2
        assert "Audio device not found" in health.errors

    def test_statistics_fields_are_optional(self):
        """total_requests and failed_requests are None when not supplied."""
        from app.models import HealthResponse, QueueStatus

        health = HealthResponse(
            status="healthy",
            uptime_seconds=0,
            queue=QueueStatus(size=0, capacity=50, utilization=0.0),
            tts_engine="piper",
            audio_output="available",
        )

        assert health.total_requests is None
        assert health.failed_requests is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorResponse:
    """Shape of ErrorResponse, the payload returned for error conditions."""

    def test_error_response_structure(self):
        """error and detail are stored as given, and a timestamp is present."""
        from app.models import ErrorResponse

        err = ErrorResponse(
            error="validation_error",
            detail="message field is required",
        )

        assert err.error == "validation_error"
        assert err.detail == "message field is required"
        assert err.timestamp is not None

    def test_timestamp_auto_generated(self):
        """Without an explicit timestamp, the field holds a datetime."""
        from app.models import ErrorResponse

        err = ErrorResponse(error="queue_full", detail="TTS queue is full")

        assert isinstance(err.timestamp, datetime)

    def test_queue_full_error_includes_queue_size(self):
        """A supplied ``queue_size`` is stored on the error response."""
        from app.models import ErrorResponse

        err = ErrorResponse(
            error="queue_full",
            detail="TTS queue is full, please retry later",
            queue_size=50,
        )

        assert err.queue_size == 50
|
||||||
|
|
||||||
|
|
||||||
|
class TestVoiceModels:
    """Shape of the voice listing models, VoiceInfo and VoicesResponse."""

    def test_voice_info_structure(self):
        """VoiceInfo stores name, language, quality, size_mb and installed as given."""
        from app.models import VoiceInfo

        info = VoiceInfo(
            name="en_US-lessac-medium",
            language="en_US",
            quality="medium",
            size_mb=63.5,
            installed=True,
        )

        assert info.name == "en_US-lessac-medium"
        assert info.language == "en_US"
        assert info.quality == "medium"
        assert info.size_mb == 63.5
        assert info.installed is True

    def test_voices_response_structure(self):
        """VoicesResponse keeps the list of voices and the default voice name."""
        from app.models import VoiceInfo, VoicesResponse

        only_voice = VoiceInfo(
            name="en_US-lessac-medium",
            language="en_US",
            quality="medium",
            size_mb=63.5,
            installed=True,
        )
        listing = VoicesResponse(
            voices=[only_voice],
            default_voice="en_US-lessac-medium",
        )

        assert len(listing.voices) == 1
        assert listing.default_voice == "en_US-lessac-medium"
|
||||||
Loading…
Reference in New Issue
Block a user