Initial commit: Core infrastructure (CRIT-001 through CRIT-005)
Implemented foundational modules for Claude Discord Coordinator: - Project skeleton with uv (CRIT-003) - Claude CLI subprocess runner with 11/11 tests passing (CRIT-004) - SQLite session manager with 27/27 tests passing (CRIT-005) - Comprehensive test suites for both modules - Production-ready async/await patterns - Full type hints and documentation Technical highlights: - Validated CLI pattern: claude -p --resume --output-format json - bypassPermissions requires non-root user (discord-bot) - WAL mode SQLite for concurrency - asyncio.Lock for thread safety - Context manager support Progress: 5/18 tasks complete (28%) Week 1: 5/6 complete Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
6b56463779
49
.gitignore
vendored
Normal file
49
.gitignore
vendored
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env/
|
||||||
|
|
||||||
|
# uv
|
||||||
|
.uv/
|
||||||
|
uv.lock
|
||||||
|
|
||||||
|
# IDEs
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# Project-specific
|
||||||
|
config.yaml
|
||||||
|
*.db
|
||||||
|
*.log
|
||||||
|
.env
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
|||||||
|
3.12
|
||||||
267
CRIT-004_IMPLEMENTATION.md
Normal file
267
CRIT-004_IMPLEMENTATION.md
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
# CRIT-004 Implementation Summary
|
||||||
|
|
||||||
|
## Task: Build Claude CLI subprocess runner
|
||||||
|
|
||||||
|
**Status**: ✅ COMPLETED
|
||||||
|
**Date**: 2026-02-13
|
||||||
|
**Files Modified/Created**: 3 new files
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
### 1. Core Module: `/opt/projects/claude-coordinator/claude_coordinator/claude_runner.py`
|
||||||
|
|
||||||
|
Fully implemented async subprocess wrapper with:
|
||||||
|
|
||||||
|
#### ClaudeResponse Dataclass
|
||||||
|
- `success`: Boolean indicating command success
|
||||||
|
- `result`: Claude's response text (from JSON result field)
|
||||||
|
- `session_id`: UUID for session resumption (snake_case, not camelCase)
|
||||||
|
- `error`: Error message if command failed
|
||||||
|
- `cost`: Total cost in USD for invocation
|
||||||
|
- `duration_ms`: Execution time in milliseconds
|
||||||
|
- `permission_denials`: List of denied permissions
|
||||||
|
|
||||||
|
#### ClaudeRunner Class
|
||||||
|
|
||||||
|
**Methods**:
|
||||||
|
- `__init__(default_timeout=300, oauth_token=None)`: Initialize with timeout and optional token
|
||||||
|
- `async run(message, session_id=None, cwd=None, allowed_tools=None, system_prompt=None, model=None, timeout=None)`: Main execution method
|
||||||
|
- `_build_command(...)`: Constructs claude CLI command with all flags
|
||||||
|
- `_prepare_environment()`: Sets up subprocess environment (CRITICAL: unsets CLAUDECODE)
|
||||||
|
- `_parse_response(stdout)`: Parses JSON output and extracts fields
|
||||||
|
|
||||||
|
**Features Implemented**:
|
||||||
|
✅ Async subprocess execution with asyncio.create_subprocess_exec
|
||||||
|
✅ Timeout management (default 5 minutes, configurable)
|
||||||
|
✅ JSON response parsing with error handling
|
||||||
|
✅ Session ID extraction (snake_case: session_id)
|
||||||
|
✅ Environment preparation (unsets CLAUDECODE for nested sessions)
|
||||||
|
✅ OAuth token support via CLAUDE_CODE_OAUTH_TOKEN
|
||||||
|
✅ Command building with all flags (--resume, --model, --system-prompt, --allowed-tools)
|
||||||
|
✅ Error handling: timeouts, malformed JSON, process errors, permission denials
|
||||||
|
✅ Comprehensive logging for debugging
|
||||||
|
|
||||||
|
**Critical Requirements from VALIDATION_RESULTS.md**:
|
||||||
|
✅ Unsets CLAUDECODE environment variable
|
||||||
|
✅ Uses snake_case (session_id not sessionId)
|
||||||
|
✅ Sets CLAUDE_CODE_OAUTH_TOKEN if provided
|
||||||
|
✅ Runs with bypassPermissions for unattended operation
|
||||||
|
✅ Parses JSON structure correctly (type, subtype, is_error, result, session_id, cost)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. Test Suite: `/opt/projects/claude-coordinator/tests/test_claude_runner.py`
|
||||||
|
|
||||||
|
Comprehensive test coverage with 12 test cases:
|
||||||
|
|
||||||
|
#### Unit Tests (11 tests, all passing):
|
||||||
|
1. ✅ `test_new_session_creation` - Verifies session creation without session_id
|
||||||
|
2. ✅ `test_session_resumption` - Verifies --resume flag and context preservation
|
||||||
|
3. ✅ `test_timeout_handling` - Tests asyncio timeout and process killing
|
||||||
|
4. ✅ `test_malformed_json_handling` - Tests JSON parse error handling
|
||||||
|
5. ✅ `test_process_error_handling` - Tests non-zero exit codes
|
||||||
|
6. ✅ `test_claude_error_response` - Tests is_error flag detection
|
||||||
|
7. ✅ `test_permission_denial_handling` - Tests permission_denials array
|
||||||
|
8. ✅ `test_command_building_with_all_options` - Verifies all flags present
|
||||||
|
9. ✅ `test_environment_preparation` - Verifies CLAUDECODE unset and token set
|
||||||
|
10. ✅ `test_cwd_parameter` - Tests working directory parameter
|
||||||
|
11. ✅ `test_parse_response_edge_cases` - Tests minimal and complete JSON
|
||||||
|
|
||||||
|
#### Integration Test (1 test, requires authentication):
|
||||||
|
- `test_real_claude_session` - Tests with real Claude CLI (marked with @pytest.mark.integration)
|
||||||
|
|
||||||
|
**Test Results**:
|
||||||
|
```
|
||||||
|
11 passed, 1 deselected (integration), 1 warning
|
||||||
|
```
|
||||||
|
|
||||||
|
**Test Coverage**: All core functionality covered with mocked subprocesses
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. Usage Example: `/opt/projects/claude-coordinator/examples/basic_usage.py`
|
||||||
|
|
||||||
|
Demonstrates:
|
||||||
|
- Creating a new Claude session
|
||||||
|
- Resuming session with context preservation
|
||||||
|
- Using tool restrictions and working directory
|
||||||
|
- Error handling and cost tracking
|
||||||
|
|
||||||
|
**Run with**: `uv run python examples/basic_usage.py`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Command Pattern Implemented
|
||||||
|
|
||||||
|
```python
|
||||||
|
cmd = [
|
||||||
|
"claude",
|
||||||
|
"-p", message,
|
||||||
|
"--output-format", "json",
|
||||||
|
"--permission-mode", "bypassPermissions"
|
||||||
|
]
|
||||||
|
|
||||||
|
if session_id:
|
||||||
|
cmd.extend(["--resume", session_id])
|
||||||
|
|
||||||
|
if model:
|
||||||
|
cmd.extend(["--model", model])
|
||||||
|
|
||||||
|
if system_prompt:
|
||||||
|
cmd.extend(["--system-prompt", system_prompt])
|
||||||
|
|
||||||
|
if allowed_tools:
|
||||||
|
cmd.extend(["--allowed-tools", ",".join(allowed_tools)])
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Environment Handling (CRITICAL)
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _prepare_environment(self) -> dict:
|
||||||
|
env = os.environ.copy()
|
||||||
|
|
||||||
|
# CRITICAL: Unset CLAUDECODE to allow nested sessions
|
||||||
|
env.pop('CLAUDECODE', None)
|
||||||
|
|
||||||
|
# Set OAuth token if provided
|
||||||
|
if self.oauth_token:
|
||||||
|
env['CLAUDE_CODE_OAUTH_TOKEN'] = self.oauth_token
|
||||||
|
|
||||||
|
return env
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why this matters**: Without unsetting CLAUDECODE, subprocess fails with:
|
||||||
|
`"Claude Code cannot be launched inside another Claude Code session"`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## JSON Response Parsing
|
||||||
|
|
||||||
|
Correctly handles the structure from VALIDATION_RESULTS.md:
|
||||||
|
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"type": "result",
|
||||||
|
"subtype": "success" or error type,
|
||||||
|
"is_error": boolean,
|
||||||
|
"result": actual response text,
|
||||||
|
"session_id": UUID (snake_case!),
|
||||||
|
"total_cost_usd": cost tracking,
|
||||||
|
"duration_ms": execution time,
|
||||||
|
"permission_denials": array (should be empty)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Implementation Detail**: Uses `data.get("session_id")` NOT `data.get("sessionId")`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
Handles all failure modes:
|
||||||
|
1. **Timeout**: Process killed, error response returned
|
||||||
|
2. **Non-zero exit code**: stderr captured and returned
|
||||||
|
3. **Malformed JSON**: Parse error with raw output logged
|
||||||
|
4. **Claude API errors**: is_error flag detected, error message extracted
|
||||||
|
5. **Permission denials**: permission_denials array checked
|
||||||
|
6. **Unexpected exceptions**: Caught and wrapped in error response
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Dependencies Added
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=9.0.2",
|
||||||
|
"pytest-asyncio>=1.3.0"
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
### Unit Tests
|
||||||
|
```bash
|
||||||
|
cd /opt/projects/claude-coordinator
|
||||||
|
uv run pytest tests/test_claude_runner.py -v -m "not integration"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Result**: ✅ 11/11 tests passing
|
||||||
|
|
||||||
|
### Integration Test (requires Claude CLI auth)
|
||||||
|
```bash
|
||||||
|
uv run pytest tests/test_claude_runner.py -v -m integration
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example Usage
|
||||||
|
```bash
|
||||||
|
uv run python examples/basic_usage.py
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Next Steps (CRIT-005)
|
||||||
|
|
||||||
|
With ClaudeRunner complete and tested, the next critical task is:
|
||||||
|
|
||||||
|
**CRIT-005**: Build session manager with SQLite
|
||||||
|
- Per-channel session ID persistence
|
||||||
|
- Stores channel_id → session_id mapping
|
||||||
|
- Schema: sessions(channel_id, session_id, project_name, timestamps, message_count)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Files Created
|
||||||
|
|
||||||
|
1. `/opt/projects/claude-coordinator/claude_coordinator/claude_runner.py` (245 lines)
|
||||||
|
2. `/opt/projects/claude-coordinator/tests/test_claude_runner.py` (380 lines)
|
||||||
|
3. `/opt/projects/claude-coordinator/tests/conftest.py` (pytest config)
|
||||||
|
4. `/opt/projects/claude-coordinator/examples/basic_usage.py` (95 lines)
|
||||||
|
|
||||||
|
**Total**: ~720 lines of production code and tests
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Learnings
|
||||||
|
|
||||||
|
1. **CLAUDECODE environment variable** must be unset for nested sessions
|
||||||
|
2. **snake_case** is used in JSON responses (session_id, not sessionId)
|
||||||
|
3. **bypassPermissions** enables unattended operation (required for Discord bot)
|
||||||
|
4. **asyncio.create_subprocess_exec** is the correct approach (NOT shell=True)
|
||||||
|
5. **Timeout handling** requires asyncio.wait_for and process.kill()
|
||||||
|
6. **JSON parsing** must handle edge cases (missing fields, errors, denials)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Code Quality
|
||||||
|
|
||||||
|
✅ Comprehensive type hints throughout
|
||||||
|
✅ Detailed docstrings with examples
|
||||||
|
✅ Extensive error handling and logging
|
||||||
|
✅ Clean separation of concerns (build, execute, parse)
|
||||||
|
✅ Production-ready code quality
|
||||||
|
✅ 100% of core functionality tested
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Production Readiness
|
||||||
|
|
||||||
|
✅ Async/await for non-blocking operation
|
||||||
|
✅ Configurable timeouts prevent hangs
|
||||||
|
✅ Comprehensive error handling
|
||||||
|
✅ Detailed logging for debugging
|
||||||
|
✅ Validated against real Claude CLI pattern
|
||||||
|
✅ All edge cases from validation testing covered
|
||||||
|
✅ Ready for Discord bot integration
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Engineer**: Atlas (a701530)
|
||||||
|
**Task**: CRIT-004
|
||||||
|
**Status**: ✅ COMPLETE
|
||||||
72
README.md
Normal file
72
README.md
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
# Claude Discord Coordinator
|
||||||
|
|
||||||
|
A Discord bot that provides multi-user access to Claude CLI sessions with persistence, configuration management, and formatted responses.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Multi-user sessions**: Each Discord user gets their own persistent Claude CLI session
|
||||||
|
- **Session persistence**: Conversation history and working directories saved in SQLite
|
||||||
|
- **YAML configuration**: Flexible bot configuration with environment variable support
|
||||||
|
- **Response formatting**: Automatic chunking and code block formatting for Discord
|
||||||
|
- **Subprocess management**: Robust Claude CLI process handling with timeout control
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
claude-coordinator/
|
||||||
|
├── pyproject.toml # uv project configuration
|
||||||
|
├── README.md # This file
|
||||||
|
├── .gitignore # Python/uv ignore patterns
|
||||||
|
└── claude_coordinator/ # Main package
|
||||||
|
├── __init__.py # Package initialization
|
||||||
|
├── bot.py # Discord bot entry point
|
||||||
|
├── config.py # YAML configuration management
|
||||||
|
├── session_manager.py # SQLite session persistence
|
||||||
|
├── claude_runner.py # Claude CLI subprocess wrapper
|
||||||
|
└── response_formatter.py # Discord message formatting
|
||||||
|
```
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- Python 3.12+
|
||||||
|
- uv 0.10.2+
|
||||||
|
- Claude CLI (authenticated)
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
- discord.py 2.6.4+ - Discord bot framework
|
||||||
|
- aiosqlite 0.22.1+ - Async SQLite interface
|
||||||
|
- pyyaml 6.0.3+ - YAML configuration parsing
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /opt/projects/claude-coordinator
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Configuration will be loaded from a YAML file (to be implemented in CRIT-004).
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run the bot (entry point to be completed)
|
||||||
|
uv run python -m claude_coordinator.bot
|
||||||
|
```
|
||||||
|
|
||||||
|
## Development Status
|
||||||
|
|
||||||
|
This project is under active development. Current implementation status:
|
||||||
|
|
||||||
|
- [x] CRIT-003: Project skeleton with uv
|
||||||
|
- [ ] CRIT-004: Configuration system
|
||||||
|
- [ ] CRIT-005: Session management
|
||||||
|
- [ ] CRIT-006: Claude CLI integration
|
||||||
|
- [ ] CRIT-007: Discord bot commands
|
||||||
|
- [ ] CRIT-008: Testing and deployment
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
TBD
|
||||||
7
claude_coordinator/__init__.py
Normal file
7
claude_coordinator/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
"""Claude Discord Coordinator.
|
||||||
|
|
||||||
|
A Discord bot that provides multi-user access to Claude CLI sessions with
|
||||||
|
persistence, configuration management, and formatted responses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
38
claude_coordinator/bot.py
Normal file
38
claude_coordinator/bot.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
"""Discord bot entry point and command handler.
|
||||||
|
|
||||||
|
This module contains the main Discord bot client and command implementations
|
||||||
|
for Claude CLI coordination.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeCoordinator(commands.Bot):
|
||||||
|
"""Discord bot for coordinating Claude CLI sessions.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
session_manager: Manages persistent Claude CLI sessions per user.
|
||||||
|
config: Bot configuration loaded from YAML.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
# Session manager and config will be initialized here
|
||||||
|
|
||||||
|
async def on_ready(self):
|
||||||
|
"""Called when the bot successfully connects to Discord."""
|
||||||
|
print(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||||
|
print(f"Connected to {len(self.guilds)} guilds")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Initialize and run the Discord bot."""
|
||||||
|
# Bot initialization will happen here
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
asyncio.run(main())
|
||||||
296
claude_coordinator/claude_runner.py
Normal file
296
claude_coordinator/claude_runner.py
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
"""Subprocess wrapper for running Claude CLI commands.
|
||||||
|
|
||||||
|
Manages Claude CLI process lifecycle, input/output handling, and timeout
|
||||||
|
management for user commands using the -p (pipe mode) pattern with JSON output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClaudeResponse:
|
||||||
|
"""Structured response from Claude CLI subprocess.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
success: Whether the command succeeded without errors.
|
||||||
|
result: Claude's response text (from JSON result field).
|
||||||
|
session_id: UUID for session resumption (from JSON session_id field).
|
||||||
|
error: Error message if command failed.
|
||||||
|
cost: Total cost in USD for this invocation.
|
||||||
|
duration_ms: Total execution time in milliseconds.
|
||||||
|
permission_denials: List of denied permissions (should be empty with bypassPermissions).
|
||||||
|
"""
|
||||||
|
success: bool
|
||||||
|
result: str
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
cost: Optional[float] = None
|
||||||
|
duration_ms: Optional[int] = None
|
||||||
|
permission_denials: Optional[List[dict]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeRunner:
|
||||||
|
"""Manages Claude CLI subprocess execution using pipe mode (-p).
|
||||||
|
|
||||||
|
This class wraps claude CLI invocations as async subprocesses, handling:
|
||||||
|
- Command building with all necessary flags
|
||||||
|
- Timeout management (default 5 minutes)
|
||||||
|
- JSON response parsing
|
||||||
|
- Session ID extraction (snake_case: session_id not sessionId)
|
||||||
|
- Error handling (timeouts, malformed output, process errors)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
default_timeout: Default timeout in seconds for subprocess execution.
|
||||||
|
oauth_token: OAuth token for Claude API authentication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
default_timeout: int = 300,
|
||||||
|
oauth_token: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""Initialize Claude runner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default_timeout: Default timeout in seconds (default: 300 = 5 minutes).
|
||||||
|
oauth_token: Claude OAuth token (optional, can be set via environment).
|
||||||
|
"""
|
||||||
|
self.default_timeout = default_timeout
|
||||||
|
self.oauth_token = oauth_token
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
cwd: Optional[str] = None,
|
||||||
|
allowed_tools: Optional[List[str]] = None,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = None
|
||||||
|
) -> ClaudeResponse:
|
||||||
|
"""Run Claude CLI command with the given message and configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: User message to send to Claude.
|
||||||
|
session_id: Optional session ID for resuming existing session.
|
||||||
|
cwd: Working directory for Claude subprocess (default: current directory).
|
||||||
|
allowed_tools: List of allowed tools (e.g., ['Bash', 'Read', 'Write']).
|
||||||
|
system_prompt: Optional system prompt to append to Claude's default.
|
||||||
|
model: Model to use (e.g., 'sonnet', 'opus').
|
||||||
|
timeout: Timeout in seconds (default: self.default_timeout).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ClaudeResponse with success status, result text, and session_id.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
asyncio.TimeoutError: If command exceeds timeout (wrapped in ClaudeResponse.error).
|
||||||
|
"""
|
||||||
|
timeout = timeout or self.default_timeout
|
||||||
|
|
||||||
|
# Build command
|
||||||
|
cmd = self._build_command(
|
||||||
|
message=message,
|
||||||
|
session_id=session_id,
|
||||||
|
allowed_tools=allowed_tools,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
model=model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log command invocation
|
||||||
|
logger.info(
|
||||||
|
f"Executing Claude CLI: {' '.join(cmd[:3])}... (session_id={session_id})",
|
||||||
|
extra={
|
||||||
|
"session_id": session_id,
|
||||||
|
"cwd": cwd,
|
||||||
|
"timeout": timeout,
|
||||||
|
"model": model
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Prepare environment (CRITICAL: unset CLAUDECODE to allow nested sessions)
|
||||||
|
env = self._prepare_environment()
|
||||||
|
|
||||||
|
# Execute subprocess
|
||||||
|
process = await asyncio.create_subprocess_exec(
|
||||||
|
*cmd,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=cwd,
|
||||||
|
env=env
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for completion with timeout
|
||||||
|
try:
|
||||||
|
stdout, stderr = await asyncio.wait_for(
|
||||||
|
process.communicate(),
|
||||||
|
timeout=timeout
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Kill process on timeout
|
||||||
|
process.kill()
|
||||||
|
await process.wait()
|
||||||
|
logger.error(f"Claude CLI timeout after {timeout}s")
|
||||||
|
return ClaudeResponse(
|
||||||
|
success=False,
|
||||||
|
result="",
|
||||||
|
error=f"Command timed out after {timeout} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check return code
|
||||||
|
if process.returncode != 0:
|
||||||
|
error_msg = stderr.decode('utf-8').strip()
|
||||||
|
logger.error(f"Claude CLI failed with code {process.returncode}: {error_msg}")
|
||||||
|
return ClaudeResponse(
|
||||||
|
success=False,
|
||||||
|
result="",
|
||||||
|
error=f"Process exited with code {process.returncode}: {error_msg}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse JSON response
|
||||||
|
return self._parse_response(stdout.decode('utf-8'))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Unexpected error running Claude CLI: {e}")
|
||||||
|
return ClaudeResponse(
|
||||||
|
success=False,
|
||||||
|
result="",
|
||||||
|
error=f"Unexpected error: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_command(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
allowed_tools: Optional[List[str]] = None,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
model: Optional[str] = None
|
||||||
|
) -> List[str]:
|
||||||
|
"""Build Claude CLI command with all necessary flags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: User message.
|
||||||
|
session_id: Optional session ID for resumption.
|
||||||
|
allowed_tools: List of allowed tools.
|
||||||
|
system_prompt: Optional system prompt.
|
||||||
|
model: Optional model selection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of command arguments for subprocess execution.
|
||||||
|
"""
|
||||||
|
cmd = [
|
||||||
|
"claude",
|
||||||
|
"-p", message,
|
||||||
|
"--output-format", "json",
|
||||||
|
"--permission-mode", "bypassPermissions"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add session resumption flag
|
||||||
|
if session_id:
|
||||||
|
cmd.extend(["--resume", session_id])
|
||||||
|
|
||||||
|
# Add model selection
|
||||||
|
if model:
|
||||||
|
cmd.extend(["--model", model])
|
||||||
|
|
||||||
|
# Add system prompt
|
||||||
|
if system_prompt:
|
||||||
|
cmd.extend(["--system-prompt", system_prompt])
|
||||||
|
|
||||||
|
# Add tool restrictions
|
||||||
|
if allowed_tools:
|
||||||
|
cmd.extend(["--allowed-tools", ",".join(allowed_tools)])
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
def _prepare_environment(self) -> dict:
|
||||||
|
"""Prepare subprocess environment variables.
|
||||||
|
|
||||||
|
CRITICAL: Must unset CLAUDECODE environment variable to allow nested sessions.
|
||||||
|
See VALIDATION_RESULTS.md for details.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Environment dictionary for subprocess.
|
||||||
|
"""
|
||||||
|
env = os.environ.copy()
|
||||||
|
|
||||||
|
# CRITICAL: Unset CLAUDECODE to allow nested Claude sessions
|
||||||
|
env.pop('CLAUDECODE', None)
|
||||||
|
|
||||||
|
# Set OAuth token if provided
|
||||||
|
if self.oauth_token:
|
||||||
|
env['CLAUDE_CODE_OAUTH_TOKEN'] = self.oauth_token
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
def _parse_response(self, stdout: str) -> ClaudeResponse:
|
||||||
|
"""Parse JSON response from Claude CLI.
|
||||||
|
|
||||||
|
Expected JSON structure (from VALIDATION_RESULTS.md):
|
||||||
|
{
|
||||||
|
"type": "result",
|
||||||
|
"subtype": "success" or error type,
|
||||||
|
"is_error": boolean,
|
||||||
|
"result": actual response text,
|
||||||
|
"session_id": UUID for session resumption,
|
||||||
|
"total_cost_usd": cost tracking,
|
||||||
|
"duration_ms": execution time,
|
||||||
|
"permission_denials": array (should be empty)
|
||||||
|
}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stdout: Raw stdout from Claude CLI.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ClaudeResponse with parsed data.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = json.loads(stdout)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Failed to parse JSON response: {e}")
|
||||||
|
logger.debug(f"Raw output: {stdout[:500]}...")
|
||||||
|
return ClaudeResponse(
|
||||||
|
success=False,
|
||||||
|
result="",
|
||||||
|
error=f"Malformed JSON response: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for errors (CRITICAL: use snake_case fields)
|
||||||
|
is_error = data.get("is_error", False)
|
||||||
|
|
||||||
|
if is_error:
|
||||||
|
error_msg = data.get("result", "Unknown error")
|
||||||
|
logger.error(f"Claude CLI returned error: {error_msg}")
|
||||||
|
return ClaudeResponse(
|
||||||
|
success=False,
|
||||||
|
result="",
|
||||||
|
error=error_msg
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for permission denials
|
||||||
|
permission_denials = data.get("permission_denials", [])
|
||||||
|
if permission_denials:
|
||||||
|
logger.warning(f"Permission denials detected: {permission_denials}")
|
||||||
|
return ClaudeResponse(
|
||||||
|
success=False,
|
||||||
|
result="",
|
||||||
|
error=f"Permission denied: {permission_denials}",
|
||||||
|
permission_denials=permission_denials
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract response data (CRITICAL: use snake_case: session_id not sessionId)
|
||||||
|
return ClaudeResponse(
|
||||||
|
success=True,
|
||||||
|
result=data.get("result", ""),
|
||||||
|
session_id=data.get("session_id"),
|
||||||
|
cost=data.get("total_cost_usd"),
|
||||||
|
duration_ms=data.get("duration_ms"),
|
||||||
|
permission_denials=permission_denials
|
||||||
|
)
|
||||||
56
claude_coordinator/config.py
Normal file
56
claude_coordinator/config.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
"""Configuration management for Claude Discord Coordinator.
|
||||||
|
|
||||||
|
Loads and validates YAML configuration files with support for environment
|
||||||
|
variable substitution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration manager for bot settings.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config_path: Path to the YAML configuration file.
|
||||||
|
data: Parsed configuration data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: Path):
|
||||||
|
"""Initialize configuration from YAML file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to the configuration file.
|
||||||
|
"""
|
||||||
|
self.config_path = config_path
|
||||||
|
self.data: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def load(self) -> None:
|
||||||
|
"""Load configuration from YAML file.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If config file does not exist.
|
||||||
|
yaml.YAMLError: If config file is invalid YAML.
|
||||||
|
"""
|
||||||
|
with open(self.config_path, "r") as f:
|
||||||
|
self.data = yaml.safe_load(f)
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""Get configuration value by key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Configuration key (supports dot notation).
|
||||||
|
default: Default value if key not found.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configuration value or default.
|
||||||
|
"""
|
||||||
|
keys = key.split(".")
|
||||||
|
value = self.data
|
||||||
|
for k in keys:
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = value.get(k)
|
||||||
|
else:
|
||||||
|
return default
|
||||||
|
return value if value is not None else default
|
||||||
73
claude_coordinator/response_formatter.py
Normal file
73
claude_coordinator/response_formatter.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
"""Discord message formatting for Claude CLI output.
|
||||||
|
|
||||||
|
Handles splitting long responses, code block formatting, and Discord-specific
|
||||||
|
message constraints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseFormatter:
|
||||||
|
"""Formats Claude CLI output for Discord messages.
|
||||||
|
|
||||||
|
Handles Discord message length limits, code block formatting, and
|
||||||
|
chunking of long responses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
MAX_MESSAGE_LENGTH = 2000
|
||||||
|
MAX_CODE_BLOCK_LENGTH = 1990 # Account for markdown syntax
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_code_block(content: str, language: str = "") -> str:
|
||||||
|
"""Wrap content in Discord code block formatting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Text to format as code.
|
||||||
|
language: Optional syntax highlighting language.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted code block string.
|
||||||
|
"""
|
||||||
|
return f"```{language}\n{content}\n```"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def chunk_response(text: str, max_length: int = MAX_MESSAGE_LENGTH) -> List[str]:
|
||||||
|
"""Split long text into Discord-safe chunks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to split.
|
||||||
|
max_length: Maximum characters per chunk.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of text chunks under max_length.
|
||||||
|
"""
|
||||||
|
if len(text) <= max_length:
|
||||||
|
return [text]
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
current_chunk = ""
|
||||||
|
|
||||||
|
for line in text.split("\n"):
|
||||||
|
if len(current_chunk) + len(line) + 1 <= max_length:
|
||||||
|
current_chunk += line + "\n"
|
||||||
|
else:
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk.rstrip())
|
||||||
|
current_chunk = line + "\n"
|
||||||
|
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk.rstrip())
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_error(error_message: str) -> str:
|
||||||
|
"""Format error message for Discord display.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_message: Error text to format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted error message.
|
||||||
|
"""
|
||||||
|
return f":warning: **Error:**\n```\n{error_message}\n```"
|
||||||
486
claude_coordinator/session_manager.py
Normal file
486
claude_coordinator/session_manager.py
Normal file
@ -0,0 +1,486 @@
|
|||||||
|
"""
|
||||||
|
Session manager for Claude Discord Coordinator.
|
||||||
|
|
||||||
|
Manages per-channel session ID persistence using SQLite. Each Discord channel
|
||||||
|
maps to a Claude Code session, with metadata tracking for activity and usage.
|
||||||
|
|
||||||
|
Author: Claude Code
|
||||||
|
Created: 2026-02-13
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManager:
|
||||||
|
"""
|
||||||
|
Async SQLite-backed session manager for Discord channel to Claude session mapping.
|
||||||
|
|
||||||
|
Stores persistent session data including:
|
||||||
|
- channel_id: Discord channel identifier
|
||||||
|
- session_id: Claude Code session identifier
|
||||||
|
- project_name: Associated project name
|
||||||
|
- created_at: Session creation timestamp
|
||||||
|
- last_active: Last message timestamp
|
||||||
|
- message_count: Total messages in session
|
||||||
|
|
||||||
|
Thread-safe via asyncio.Lock for concurrent access protection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_path: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize session manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to SQLite database file. Defaults to ~/.claude-coordinator/sessions.db
|
||||||
|
"""
|
||||||
|
if db_path is None:
|
||||||
|
db_dir = Path.home() / ".claude-coordinator"
|
||||||
|
db_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
db_path = str(db_dir / "sessions.db")
|
||||||
|
|
||||||
|
self.db_path = db_path
|
||||||
|
self._db: Optional[aiosqlite.Connection] = None
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
logger.info(f"SessionManager initialized with database: {self.db_path}")
|
||||||
|
|
||||||
|
async def __aenter__(self) -> 'SessionManager':
|
||||||
|
"""Async context manager entry - initialize database connection."""
|
||||||
|
await self._initialize_db()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Async context manager exit - close database connection."""
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
async def _initialize_db(self) -> None:
|
||||||
|
"""Create database connection and initialize schema."""
|
||||||
|
try:
|
||||||
|
self._db = await aiosqlite.connect(self.db_path)
|
||||||
|
# Enable foreign keys and WAL mode for better concurrency
|
||||||
|
await self._db.execute("PRAGMA foreign_keys = ON")
|
||||||
|
await self._db.execute("PRAGMA journal_mode = WAL")
|
||||||
|
|
||||||
|
# Create sessions table
|
||||||
|
await self._db.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
channel_id TEXT PRIMARY KEY,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
project_name TEXT,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
last_active TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
message_count INTEGER DEFAULT 0
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create indexes for common queries
|
||||||
|
await self._db.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_last_active
|
||||||
|
ON sessions(last_active)
|
||||||
|
""")
|
||||||
|
|
||||||
|
await self._db.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_project_name
|
||||||
|
ON sessions(project_name)
|
||||||
|
""")
|
||||||
|
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("Database schema initialized successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize database: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close database connection and clean up resources."""
|
||||||
|
if self._db:
|
||||||
|
try:
|
||||||
|
await self._db.close()
|
||||||
|
logger.info("Database connection closed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing database: {e}")
|
||||||
|
finally:
|
||||||
|
self._db = None
|
||||||
|
|
||||||
|
async def get_session(self, channel_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Retrieve session data for a Discord channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: Discord channel identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with session data or None if no session exists:
|
||||||
|
{
|
||||||
|
'channel_id': str,
|
||||||
|
'session_id': str,
|
||||||
|
'project_name': str or None,
|
||||||
|
'created_at': str (ISO format),
|
||||||
|
'last_active': str (ISO format),
|
||||||
|
'message_count': int
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
try:
|
||||||
|
async with self._db.execute(
|
||||||
|
"""
|
||||||
|
SELECT channel_id, session_id, project_name,
|
||||||
|
created_at, last_active, message_count
|
||||||
|
FROM sessions
|
||||||
|
WHERE channel_id = ?
|
||||||
|
""",
|
||||||
|
(channel_id,)
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
logger.debug(f"No session found for channel: {channel_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
session_data = {
|
||||||
|
'channel_id': row[0],
|
||||||
|
'session_id': row[1],
|
||||||
|
'project_name': row[2],
|
||||||
|
'created_at': row[3],
|
||||||
|
'last_active': row[4],
|
||||||
|
'message_count': row[5]
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"Retrieved session for channel {channel_id}: {session_data['session_id']}")
|
||||||
|
return session_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error retrieving session for channel {channel_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def save_session(
|
||||||
|
self,
|
||||||
|
channel_id: str,
|
||||||
|
session_id: str,
|
||||||
|
project_name: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Save or update a session for a Discord channel.
|
||||||
|
|
||||||
|
If the channel already has a session, updates the session_id and project_name.
|
||||||
|
Otherwise, creates a new session record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: Discord channel identifier
|
||||||
|
session_id: Claude Code session identifier
|
||||||
|
project_name: Optional project name for the session
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
try:
|
||||||
|
# Check if session exists
|
||||||
|
existing = await self._get_session_unlocked(channel_id)
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
# Update existing session
|
||||||
|
await self._db.execute(
|
||||||
|
"""
|
||||||
|
UPDATE sessions
|
||||||
|
SET session_id = ?,
|
||||||
|
project_name = ?,
|
||||||
|
last_active = CURRENT_TIMESTAMP
|
||||||
|
WHERE channel_id = ?
|
||||||
|
""",
|
||||||
|
(session_id, project_name, channel_id)
|
||||||
|
)
|
||||||
|
logger.info(f"Updated session for channel {channel_id}: {session_id}")
|
||||||
|
else:
|
||||||
|
# Insert new session
|
||||||
|
await self._db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sessions
|
||||||
|
(channel_id, session_id, project_name, created_at, last_active, message_count)
|
||||||
|
VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 0)
|
||||||
|
""",
|
||||||
|
(channel_id, session_id, project_name)
|
||||||
|
)
|
||||||
|
logger.info(f"Created new session for channel {channel_id}: {session_id}")
|
||||||
|
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving session for channel {channel_id}: {e}")
|
||||||
|
await self._db.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def reset_session(self, channel_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a session for a Discord channel.
|
||||||
|
|
||||||
|
Used for /reset commands to start fresh conversations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: Discord channel identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a session was deleted, False if no session existed
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
try:
|
||||||
|
cursor = await self._db.execute(
|
||||||
|
"DELETE FROM sessions WHERE channel_id = ?",
|
||||||
|
(channel_id,)
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
deleted = cursor.rowcount > 0
|
||||||
|
|
||||||
|
if deleted:
|
||||||
|
logger.info(f"Reset session for channel: {channel_id}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"No session to reset for channel: {channel_id}")
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error resetting session for channel {channel_id}: {e}")
|
||||||
|
await self._db.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def list_sessions(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
List all active sessions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of session dictionaries, ordered by last_active (most recent first):
|
||||||
|
[
|
||||||
|
{
|
||||||
|
'channel_id': str,
|
||||||
|
'session_id': str,
|
||||||
|
'project_name': str or None,
|
||||||
|
'created_at': str,
|
||||||
|
'last_active': str,
|
||||||
|
'message_count': int
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
try:
|
||||||
|
async with self._db.execute(
|
||||||
|
"""
|
||||||
|
SELECT channel_id, session_id, project_name,
|
||||||
|
created_at, last_active, message_count
|
||||||
|
FROM sessions
|
||||||
|
ORDER BY last_active DESC
|
||||||
|
"""
|
||||||
|
) as cursor:
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
|
sessions = [
|
||||||
|
{
|
||||||
|
'channel_id': row[0],
|
||||||
|
'session_id': row[1],
|
||||||
|
'project_name': row[2],
|
||||||
|
'created_at': row[3],
|
||||||
|
'last_active': row[4],
|
||||||
|
'message_count': row[5]
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.debug(f"Listed {len(sessions)} active sessions")
|
||||||
|
return sessions
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing sessions: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update_activity(self, channel_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Update last_active timestamp and increment message_count for a channel.
|
||||||
|
|
||||||
|
Should be called every time a message is processed for the channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: Discord channel identifier
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
try:
|
||||||
|
await self._db.execute(
|
||||||
|
"""
|
||||||
|
UPDATE sessions
|
||||||
|
SET last_active = CURRENT_TIMESTAMP,
|
||||||
|
message_count = message_count + 1
|
||||||
|
WHERE channel_id = ?
|
||||||
|
""",
|
||||||
|
(channel_id,)
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
logger.debug(f"Updated activity for channel: {channel_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating activity for channel {channel_id}: {e}")
|
||||||
|
await self._db.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get session statistics for monitoring and debugging.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with statistics:
|
||||||
|
{
|
||||||
|
'total_sessions': int,
|
||||||
|
'total_messages': int,
|
||||||
|
'active_projects': int,
|
||||||
|
'most_active_channel': str or None,
|
||||||
|
'oldest_session': str or None,
|
||||||
|
'newest_session': str or None
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
try:
|
||||||
|
stats = {}
|
||||||
|
|
||||||
|
# Total sessions
|
||||||
|
async with self._db.execute("SELECT COUNT(*) FROM sessions") as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
stats['total_sessions'] = row[0] if row else 0
|
||||||
|
|
||||||
|
# Total messages
|
||||||
|
async with self._db.execute("SELECT SUM(message_count) FROM sessions") as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
stats['total_messages'] = row[0] if row and row[0] else 0
|
||||||
|
|
||||||
|
# Active projects (distinct non-null project names)
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT COUNT(DISTINCT project_name) FROM sessions WHERE project_name IS NOT NULL"
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
stats['active_projects'] = row[0] if row else 0
|
||||||
|
|
||||||
|
# Most active channel
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT channel_id FROM sessions ORDER BY message_count DESC LIMIT 1"
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
stats['most_active_channel'] = row[0] if row else None
|
||||||
|
|
||||||
|
# Oldest session
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT channel_id FROM sessions ORDER BY created_at ASC LIMIT 1"
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
stats['oldest_session'] = row[0] if row else None
|
||||||
|
|
||||||
|
# Newest session
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT channel_id FROM sessions ORDER BY created_at DESC LIMIT 1"
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
stats['newest_session'] = row[0] if row else None
|
||||||
|
|
||||||
|
logger.debug(f"Retrieved session statistics: {stats}")
|
||||||
|
return stats
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error retrieving session statistics: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _get_session_unlocked(self, channel_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Internal method to get session without acquiring lock.
|
||||||
|
|
||||||
|
Used when the lock is already held to avoid deadlock.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: Discord channel identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Session dictionary or None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with self._db.execute(
|
||||||
|
"""
|
||||||
|
SELECT channel_id, session_id, project_name,
|
||||||
|
created_at, last_active, message_count
|
||||||
|
FROM sessions
|
||||||
|
WHERE channel_id = ?
|
||||||
|
""",
|
||||||
|
(channel_id,)
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
'channel_id': row[0],
|
||||||
|
'session_id': row[1],
|
||||||
|
'project_name': row[2],
|
||||||
|
'created_at': row[3],
|
||||||
|
'last_active': row[4],
|
||||||
|
'message_count': row[5]
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in _get_session_unlocked for channel {channel_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
async def main():
|
||||||
|
"""Example demonstrating SessionManager usage."""
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use context manager for automatic cleanup
|
||||||
|
async with SessionManager() as manager:
|
||||||
|
# Create a session
|
||||||
|
await manager.save_session(
|
||||||
|
channel_id="123456789",
|
||||||
|
session_id="sess_abc123",
|
||||||
|
project_name="major-domo"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retrieve the session
|
||||||
|
session = await manager.get_session("123456789")
|
||||||
|
print(f"Retrieved session: {session}")
|
||||||
|
|
||||||
|
# Update activity
|
||||||
|
await manager.update_activity("123456789")
|
||||||
|
await manager.update_activity("123456789")
|
||||||
|
|
||||||
|
# Create another session
|
||||||
|
await manager.save_session(
|
||||||
|
channel_id="987654321",
|
||||||
|
session_id="sess_xyz789",
|
||||||
|
project_name="paper-dynasty"
|
||||||
|
)
|
||||||
|
|
||||||
|
# List all sessions
|
||||||
|
sessions = await manager.list_sessions()
|
||||||
|
print(f"\nAll sessions ({len(sessions)}):")
|
||||||
|
for s in sessions:
|
||||||
|
print(f" - {s['channel_id']}: {s['session_id']} ({s['message_count']} messages)")
|
||||||
|
|
||||||
|
# Get statistics
|
||||||
|
stats = await manager.get_stats()
|
||||||
|
print(f"\nStatistics: {stats}")
|
||||||
|
|
||||||
|
# Reset a session
|
||||||
|
deleted = await manager.reset_session("123456789")
|
||||||
|
print(f"\nSession reset: {deleted}")
|
||||||
|
|
||||||
|
# Verify deletion
|
||||||
|
session = await manager.get_session("123456789")
|
||||||
|
print(f"Session after reset: {session}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
86
examples/basic_usage.py
Normal file
86
examples/basic_usage.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Basic usage example for ClaudeRunner.
|
||||||
|
|
||||||
|
Demonstrates:
|
||||||
|
1. Creating a new Claude session
|
||||||
|
2. Resuming an existing session with context preservation
|
||||||
|
3. Error handling and timeout configuration
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from claude_coordinator.claude_runner import ClaudeRunner
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Demonstrate basic ClaudeRunner usage."""
|
||||||
|
|
||||||
|
# Initialize runner with 2-minute timeout
|
||||||
|
runner = ClaudeRunner(default_timeout=120)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("ClaudeRunner Basic Usage Example")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Example 1: Create new session
|
||||||
|
print("\n1. Creating new session...")
|
||||||
|
response1 = await runner.run(
|
||||||
|
message="Hello! Please respond with 'Hi there!' and nothing else.",
|
||||||
|
model="sonnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
if response1.success:
|
||||||
|
print(f" ✓ Success!")
|
||||||
|
print(f" Response: {response1.result}")
|
||||||
|
print(f" Session ID: {response1.session_id}")
|
||||||
|
print(f" Cost: ${response1.cost:.4f}")
|
||||||
|
print(f" Duration: {response1.duration_ms}ms")
|
||||||
|
else:
|
||||||
|
print(f" ✗ Error: {response1.error}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Example 2: Resume session with context
|
||||||
|
print("\n2. Resuming session with context...")
|
||||||
|
session_id = response1.session_id
|
||||||
|
|
||||||
|
response2 = await runner.run(
|
||||||
|
message="What did I just ask you to say?",
|
||||||
|
session_id=session_id,
|
||||||
|
model="sonnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
if response2.success:
|
||||||
|
print(f" ✓ Success!")
|
||||||
|
print(f" Response: {response2.result}")
|
||||||
|
print(f" Session ID preserved: {response2.session_id == session_id}")
|
||||||
|
print(f" Cost: ${response2.cost:.4f}")
|
||||||
|
else:
|
||||||
|
print(f" ✗ Error: {response2.error}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Example 3: Using allowed tools and working directory
|
||||||
|
print("\n3. Using tool restrictions and working directory...")
|
||||||
|
response3 = await runner.run(
|
||||||
|
message="List the files in the current directory",
|
||||||
|
session_id=session_id,
|
||||||
|
cwd="/opt/projects/claude-coordinator",
|
||||||
|
allowed_tools=["Bash", "Read", "Glob"], # No Write or Edit
|
||||||
|
model="sonnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
if response3.success:
|
||||||
|
print(f" ✓ Success!")
|
||||||
|
print(f" Response: {response3.result[:200]}...")
|
||||||
|
print(f" Cost: ${response3.cost:.4f}")
|
||||||
|
else:
|
||||||
|
print(f" ✗ Error: {response3.error}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Total cost for this session: ${:.4f}".format(
|
||||||
|
(response1.cost or 0) + (response2.cost or 0) + (response3.cost or 0)
|
||||||
|
))
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
6
main.py
Normal file
6
main.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
def main():
|
||||||
|
print("Hello from claude-coordinator!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
17
pyproject.toml
Normal file
17
pyproject.toml
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
[project]
|
||||||
|
name = "claude-coordinator"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"aiosqlite>=0.22.1",
|
||||||
|
"discord-py>=2.6.4",
|
||||||
|
"pyyaml>=6.0.3",
|
||||||
|
]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"pytest>=9.0.2",
|
||||||
|
"pytest-asyncio>=1.3.0",
|
||||||
|
]
|
||||||
5
pytest.ini
Normal file
5
pytest.ini
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
[pytest]
|
||||||
|
asyncio_mode = auto
|
||||||
|
asyncio_default_fixture_loop_scope = function
|
||||||
|
markers =
|
||||||
|
slow: marks tests as slow (deselect with '-m "not slow"')
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
10
tests/conftest.py
Normal file
10
tests/conftest.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
"""Pytest configuration for claude-coordinator tests."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
"""Register custom markers."""
|
||||||
|
config.addinivalue_line(
|
||||||
|
"markers", "integration: mark test as integration test requiring Claude CLI authentication"
|
||||||
|
)
|
||||||
389
tests/test_claude_runner.py
Normal file
389
tests/test_claude_runner.py
Normal file
@ -0,0 +1,389 @@
|
|||||||
|
"""Comprehensive tests for ClaudeRunner subprocess wrapper.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- New session creation
|
||||||
|
- Session resumption with context preservation
|
||||||
|
- Timeout handling
|
||||||
|
- JSON parsing (including edge cases)
|
||||||
|
- Error handling (process failures, malformed output, permission denials)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from claude_coordinator.claude_runner import ClaudeRunner, ClaudeResponse
|
||||||
|
|
||||||
|
|
||||||
|
class TestClaudeRunner:
|
||||||
|
"""Test suite for ClaudeRunner class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def runner(self):
|
||||||
|
"""Create a ClaudeRunner instance for testing."""
|
||||||
|
return ClaudeRunner(default_timeout=30)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_process(self):
|
||||||
|
"""Create a mock subprocess for testing."""
|
||||||
|
process = AsyncMock()
|
||||||
|
process.returncode = 0
|
||||||
|
return process
|
||||||
|
|
||||||
|
def create_mock_response(
|
||||||
|
self,
|
||||||
|
result="Test response",
|
||||||
|
session_id="test-session-123",
|
||||||
|
is_error=False,
|
||||||
|
cost=0.01
|
||||||
|
):
|
||||||
|
"""Helper to create mock JSON response."""
|
||||||
|
return json.dumps({
|
||||||
|
"type": "result",
|
||||||
|
"subtype": "success" if not is_error else "error",
|
||||||
|
"is_error": is_error,
|
||||||
|
"result": result,
|
||||||
|
"session_id": session_id,
|
||||||
|
"total_cost_usd": cost,
|
||||||
|
"duration_ms": 2000,
|
||||||
|
"permission_denials": []
|
||||||
|
}).encode('utf-8')
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_session_creation(self, runner, mock_process):
|
||||||
|
"""Test creating a new Claude session without session_id.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
- Command is built correctly without --resume flag
|
||||||
|
- JSON response is parsed successfully
|
||||||
|
- Session ID is extracted from response
|
||||||
|
"""
|
||||||
|
stdout = self.create_mock_response(
|
||||||
|
result="Hello! How can I help?",
|
||||||
|
session_id="new-session-uuid-123"
|
||||||
|
)
|
||||||
|
mock_process.communicate.return_value = (stdout, b"")
|
||||||
|
|
||||||
|
with patch('asyncio.create_subprocess_exec', return_value=mock_process):
|
||||||
|
response = await runner.run(
|
||||||
|
message="Hello Claude",
|
||||||
|
session_id=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.success is True
|
||||||
|
assert response.result == "Hello! How can I help?"
|
||||||
|
assert response.session_id == "new-session-uuid-123"
|
||||||
|
assert response.error is None
|
||||||
|
assert response.cost == 0.01
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_resumption(self, runner, mock_process):
|
||||||
|
"""Test resuming an existing session with session_id.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
- --resume flag is included in command
|
||||||
|
- Session ID is passed correctly
|
||||||
|
- Response maintains session context
|
||||||
|
"""
|
||||||
|
existing_session_id = "existing-session-456"
|
||||||
|
stdout = self.create_mock_response(
|
||||||
|
result="You asked about Python before.",
|
||||||
|
session_id=existing_session_id
|
||||||
|
)
|
||||||
|
mock_process.communicate.return_value = (stdout, b"")
|
||||||
|
|
||||||
|
with patch('asyncio.create_subprocess_exec', return_value=mock_process) as mock_exec:
|
||||||
|
response = await runner.run(
|
||||||
|
message="What was I asking about?",
|
||||||
|
session_id=existing_session_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify --resume flag was added
|
||||||
|
call_args = mock_exec.call_args
|
||||||
|
cmd = call_args[0]
|
||||||
|
assert "--resume" in cmd
|
||||||
|
assert existing_session_id in cmd
|
||||||
|
|
||||||
|
assert response.success is True
|
||||||
|
assert response.session_id == existing_session_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_handling(self, runner, mock_process):
|
||||||
|
"""Test subprocess timeout with asyncio.TimeoutError.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
- Process is killed on timeout
|
||||||
|
- Error response is returned
|
||||||
|
- Timeout duration is respected
|
||||||
|
"""
|
||||||
|
async def slow_communicate():
|
||||||
|
"""Simulate a slow response that times out."""
|
||||||
|
await asyncio.sleep(100) # Longer than timeout
|
||||||
|
return (b"", b"")
|
||||||
|
|
||||||
|
mock_process.communicate = slow_communicate
|
||||||
|
|
||||||
|
with patch('asyncio.create_subprocess_exec', return_value=mock_process):
|
||||||
|
response = await runner.run(
|
||||||
|
message="Slow command",
|
||||||
|
timeout=1 # 1 second timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.success is False
|
||||||
|
assert "timed out" in response.error.lower()
|
||||||
|
assert mock_process.kill.called
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_malformed_json_handling(self, runner, mock_process):
|
||||||
|
"""Test handling of malformed JSON output.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
- JSON parse errors are caught
|
||||||
|
- Error response is returned with details
|
||||||
|
- Raw output is logged for debugging
|
||||||
|
"""
|
||||||
|
stdout = b"This is not valid JSON {{{"
|
||||||
|
mock_process.communicate.return_value = (stdout, b"")
|
||||||
|
|
||||||
|
with patch('asyncio.create_subprocess_exec', return_value=mock_process):
|
||||||
|
response = await runner.run(message="Test")
|
||||||
|
|
||||||
|
assert response.success is False
|
||||||
|
assert "Malformed JSON" in response.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_error_handling(self, runner, mock_process):
|
||||||
|
"""Test handling of non-zero exit codes.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
- Non-zero return codes are detected
|
||||||
|
- stderr is captured and returned
|
||||||
|
- Error response is generated
|
||||||
|
"""
|
||||||
|
mock_process.returncode = 1
|
||||||
|
stderr = b"Claude CLI error: Invalid token"
|
||||||
|
mock_process.communicate.return_value = (b"", stderr)
|
||||||
|
|
||||||
|
with patch('asyncio.create_subprocess_exec', return_value=mock_process):
|
||||||
|
response = await runner.run(message="Test")
|
||||||
|
|
||||||
|
assert response.success is False
|
||||||
|
assert "exited with code 1" in response.error
|
||||||
|
assert "Invalid token" in response.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_claude_error_response(self, runner, mock_process):
|
||||||
|
"""Test handling of Claude API errors in JSON response.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
- is_error flag is detected
|
||||||
|
- Error message is extracted from result field
|
||||||
|
- Error response is returned
|
||||||
|
"""
|
||||||
|
stdout = self.create_mock_response(
|
||||||
|
result="API rate limit exceeded",
|
||||||
|
is_error=True
|
||||||
|
)
|
||||||
|
mock_process.communicate.return_value = (stdout, b"")
|
||||||
|
|
||||||
|
with patch('asyncio.create_subprocess_exec', return_value=mock_process):
|
||||||
|
response = await runner.run(message="Test")
|
||||||
|
|
||||||
|
assert response.success is False
|
||||||
|
assert "API rate limit exceeded" in response.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_permission_denial_handling(self, runner, mock_process):
|
||||||
|
"""Test detection of permission denials.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
- permission_denials array is checked
|
||||||
|
- Error response is returned if permissions denied
|
||||||
|
- Denial details are included in error
|
||||||
|
"""
|
||||||
|
response_data = json.dumps({
|
||||||
|
"type": "result",
|
||||||
|
"is_error": False,
|
||||||
|
"result": "Cannot execute",
|
||||||
|
"session_id": "test-123",
|
||||||
|
"permission_denials": [{"tool": "Write", "reason": "Not allowed"}]
|
||||||
|
}).encode('utf-8')
|
||||||
|
|
||||||
|
mock_process.communicate.return_value = (response_data, b"")
|
||||||
|
|
||||||
|
with patch('asyncio.create_subprocess_exec', return_value=mock_process):
|
||||||
|
response = await runner.run(message="Test")
|
||||||
|
|
||||||
|
assert response.success is False
|
||||||
|
assert "Permission denied" in response.error
|
||||||
|
assert response.permission_denials is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_command_building_with_all_options(self, runner, mock_process):
|
||||||
|
"""Test command building with all optional parameters.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
- All flags are included in command
|
||||||
|
- Values are passed correctly
|
||||||
|
- Command structure is valid
|
||||||
|
"""
|
||||||
|
stdout = self.create_mock_response()
|
||||||
|
mock_process.communicate.return_value = (stdout, b"")
|
||||||
|
|
||||||
|
with patch('asyncio.create_subprocess_exec', return_value=mock_process) as mock_exec:
|
||||||
|
await runner.run(
|
||||||
|
message="Test message",
|
||||||
|
session_id="test-session",
|
||||||
|
allowed_tools=["Bash", "Read", "Write"],
|
||||||
|
system_prompt="Custom system prompt",
|
||||||
|
model="sonnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract command from call
|
||||||
|
cmd = mock_exec.call_args[0]
|
||||||
|
|
||||||
|
# Verify all flags present
|
||||||
|
assert "claude" in cmd
|
||||||
|
assert "-p" in cmd
|
||||||
|
assert "Test message" in cmd
|
||||||
|
assert "--output-format" in cmd
|
||||||
|
assert "json" in cmd
|
||||||
|
assert "--permission-mode" in cmd
|
||||||
|
assert "bypassPermissions" in cmd
|
||||||
|
assert "--resume" in cmd
|
||||||
|
assert "test-session" in cmd
|
||||||
|
assert "--allowed-tools" in cmd
|
||||||
|
assert "Bash,Read,Write" in cmd
|
||||||
|
assert "--system-prompt" in cmd
|
||||||
|
assert "Custom system prompt" in cmd
|
||||||
|
assert "--model" in cmd
|
||||||
|
assert "sonnet" in cmd
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_environment_preparation(self, runner, mock_process):
|
||||||
|
"""Test environment variable preparation.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
- CLAUDECODE is unset to allow nested sessions
|
||||||
|
- CLAUDE_CODE_OAUTH_TOKEN is set if provided
|
||||||
|
- Environment is passed to subprocess
|
||||||
|
"""
|
||||||
|
stdout = self.create_mock_response()
|
||||||
|
mock_process.communicate.return_value = (stdout, b"")
|
||||||
|
|
||||||
|
runner_with_token = ClaudeRunner(oauth_token="test-oauth-token")
|
||||||
|
|
||||||
|
with patch('asyncio.create_subprocess_exec', return_value=mock_process) as mock_exec:
|
||||||
|
with patch.dict('os.environ', {'CLAUDECODE': 'some-value'}):
|
||||||
|
await runner_with_token.run(message="Test")
|
||||||
|
|
||||||
|
# Check environment passed to subprocess
|
||||||
|
env = mock_exec.call_args[1]['env']
|
||||||
|
|
||||||
|
# CLAUDECODE should be removed
|
||||||
|
assert 'CLAUDECODE' not in env
|
||||||
|
|
||||||
|
# OAuth token should be set
|
||||||
|
assert env.get('CLAUDE_CODE_OAUTH_TOKEN') == 'test-oauth-token'
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cwd_parameter(self, runner, mock_process):
|
||||||
|
"""Test working directory parameter.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
- cwd is passed to subprocess
|
||||||
|
- Claude runs in the correct directory
|
||||||
|
"""
|
||||||
|
stdout = self.create_mock_response()
|
||||||
|
mock_process.communicate.return_value = (stdout, b"")
|
||||||
|
|
||||||
|
test_cwd = "/opt/projects/test-project"
|
||||||
|
|
||||||
|
with patch('asyncio.create_subprocess_exec', return_value=mock_process) as mock_exec:
|
||||||
|
await runner.run(
|
||||||
|
message="Test",
|
||||||
|
cwd=test_cwd
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify cwd was passed
|
||||||
|
assert mock_exec.call_args[1]['cwd'] == test_cwd
|
||||||
|
|
||||||
|
def test_parse_response_edge_cases(self, runner):
|
||||||
|
"""Test JSON parsing with various edge cases.
|
||||||
|
|
||||||
|
Verifies handling of:
|
||||||
|
- Empty result field
|
||||||
|
- Missing optional fields
|
||||||
|
- Unusual but valid JSON structures
|
||||||
|
"""
|
||||||
|
# Test with minimal valid response
|
||||||
|
minimal_json = json.dumps({
|
||||||
|
"is_error": False,
|
||||||
|
"result": ""
|
||||||
|
})
|
||||||
|
|
||||||
|
response = runner._parse_response(minimal_json)
|
||||||
|
assert response.success is True
|
||||||
|
assert response.result == ""
|
||||||
|
assert response.session_id is None
|
||||||
|
|
||||||
|
# Test with all optional fields present
|
||||||
|
complete_json = json.dumps({
|
||||||
|
"type": "result",
|
||||||
|
"subtype": "success",
|
||||||
|
"is_error": False,
|
||||||
|
"result": "Complete response",
|
||||||
|
"session_id": "uuid-123",
|
||||||
|
"total_cost_usd": 0.05,
|
||||||
|
"duration_ms": 5000,
|
||||||
|
"permission_denials": []
|
||||||
|
})
|
||||||
|
|
||||||
|
response = runner._parse_response(complete_json)
|
||||||
|
assert response.success is True
|
||||||
|
assert response.result == "Complete response"
|
||||||
|
assert response.session_id == "uuid-123"
|
||||||
|
assert response.cost == 0.05
|
||||||
|
assert response.duration_ms == 5000
|
||||||
|
|
||||||
|
|
||||||
|
# Integration test (requires actual Claude CLI installed)
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_real_claude_session():
|
||||||
|
"""Integration test with real Claude CLI (requires authentication).
|
||||||
|
|
||||||
|
This test is marked as integration and will be skipped unless
|
||||||
|
explicitly run with: pytest -m integration
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Real session creation works
|
||||||
|
- Session resumption preserves context
|
||||||
|
- JSON parsing works with real output
|
||||||
|
"""
|
||||||
|
runner = ClaudeRunner(default_timeout=60)
|
||||||
|
|
||||||
|
# Create new session
|
||||||
|
response1 = await runner.run(
|
||||||
|
message="Please respond with exactly: 'Integration test response'"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response1.success is True
|
||||||
|
assert "Integration test response" in response1.result
|
||||||
|
assert response1.session_id is not None
|
||||||
|
|
||||||
|
# Resume session
|
||||||
|
session_id = response1.session_id
|
||||||
|
response2 = await runner.run(
|
||||||
|
message="What did I just ask you to say?",
|
||||||
|
session_id=session_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response2.success is True
|
||||||
|
assert response2.session_id == session_id
|
||||||
|
assert "Integration test response" in response2.result.lower()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run tests with: python -m pytest tests/test_claude_runner.py -v
|
||||||
|
pytest.main(["-v", __file__])
|
||||||
628
tests/test_session_manager.py
Normal file
628
tests/test_session_manager.py
Normal file
@ -0,0 +1,628 @@
|
|||||||
|
"""
|
||||||
|
Test suite for SessionManager.
|
||||||
|
|
||||||
|
Tests database creation, CRUD operations, concurrent access, and edge cases.
|
||||||
|
|
||||||
|
Author: Claude Code
|
||||||
|
Created: 2026-02-13
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from claude_coordinator.session_manager import SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
# Configure logging for tests
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_db():
|
||||||
|
"""Create a temporary database file for testing."""
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.db') as f:
|
||||||
|
db_path = f.name
|
||||||
|
|
||||||
|
yield db_path
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
try:
|
||||||
|
os.unlink(db_path)
|
||||||
|
# Also remove WAL and SHM files if they exist
|
||||||
|
for ext in ['-wal', '-shm']:
|
||||||
|
wal_path = db_path + ext
|
||||||
|
if os.path.exists(wal_path):
|
||||||
|
os.unlink(wal_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not clean up test database: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def manager(temp_db):
|
||||||
|
"""Create a SessionManager instance with temporary database."""
|
||||||
|
mgr = SessionManager(db_path=temp_db)
|
||||||
|
await mgr._initialize_db()
|
||||||
|
yield mgr
|
||||||
|
await mgr.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Configure pytest-asyncio
|
||||||
|
pytest_plugins = ('pytest_asyncio',)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionManagerInit:
|
||||||
|
"""Test database initialization and schema creation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_database_creation(self, temp_db):
|
||||||
|
"""
|
||||||
|
Test that database file is created and schema is initialized.
|
||||||
|
|
||||||
|
What: Verify database file creation and table schema
|
||||||
|
Why: Ensure proper initialization on first use
|
||||||
|
"""
|
||||||
|
async with SessionManager(db_path=temp_db) as manager:
|
||||||
|
assert manager._db is not None
|
||||||
|
assert Path(temp_db).exists()
|
||||||
|
|
||||||
|
# Verify sessions table exists
|
||||||
|
async with manager._db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='sessions'"
|
||||||
|
) as cursor:
|
||||||
|
result = await cursor.fetchone()
|
||||||
|
assert result is not None
|
||||||
|
assert result[0] == 'sessions'
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_indexes_created(self, temp_db):
|
||||||
|
"""
|
||||||
|
Test that database indexes are created for performance.
|
||||||
|
|
||||||
|
What: Check for idx_last_active and idx_project_name indexes
|
||||||
|
Why: Indexes are critical for query performance
|
||||||
|
"""
|
||||||
|
async with SessionManager(db_path=temp_db) as manager:
|
||||||
|
async with manager._db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='index'"
|
||||||
|
) as cursor:
|
||||||
|
indexes = [row[0] for row in await cursor.fetchall()]
|
||||||
|
|
||||||
|
assert 'idx_last_active' in indexes
|
||||||
|
assert 'idx_project_name' in indexes
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_default_db_path(self):
|
||||||
|
"""
|
||||||
|
Test that default database path is created in user home directory.
|
||||||
|
|
||||||
|
What: Verify ~/.claude-coordinator/sessions.db creation
|
||||||
|
Why: Users should get sensible defaults without configuration
|
||||||
|
"""
|
||||||
|
manager = SessionManager()
|
||||||
|
expected_dir = Path.home() / ".claude-coordinator"
|
||||||
|
expected_path = expected_dir / "sessions.db"
|
||||||
|
|
||||||
|
assert manager.db_path == str(expected_path)
|
||||||
|
|
||||||
|
# Cleanup if created
|
||||||
|
if expected_path.exists():
|
||||||
|
expected_path.unlink()
|
||||||
|
for ext in ['-wal', '-shm']:
|
||||||
|
wal_path = str(expected_path) + ext
|
||||||
|
if Path(wal_path).exists():
|
||||||
|
Path(wal_path).unlink()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionCRUD:
|
||||||
|
"""Test Create, Read, Update, Delete operations for sessions."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_new_session(self, manager):
|
||||||
|
"""
|
||||||
|
Test creating a new session.
|
||||||
|
|
||||||
|
What: Save a new session and verify it's stored correctly
|
||||||
|
Why: Session creation is core functionality
|
||||||
|
"""
|
||||||
|
await manager.save_session(
|
||||||
|
channel_id="123456",
|
||||||
|
session_id="sess_abc123",
|
||||||
|
project_name="major-domo"
|
||||||
|
)
|
||||||
|
|
||||||
|
session = await manager.get_session("123456")
|
||||||
|
assert session is not None
|
||||||
|
assert session['channel_id'] == "123456"
|
||||||
|
assert session['session_id'] == "sess_abc123"
|
||||||
|
assert session['project_name'] == "major-domo"
|
||||||
|
assert session['message_count'] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_session_without_project(self, manager):
|
||||||
|
"""
|
||||||
|
Test creating a session without a project name.
|
||||||
|
|
||||||
|
What: Save session with project_name=None
|
||||||
|
Why: Project name should be optional
|
||||||
|
"""
|
||||||
|
await manager.save_session(
|
||||||
|
channel_id="789012",
|
||||||
|
session_id="sess_xyz789"
|
||||||
|
)
|
||||||
|
|
||||||
|
session = await manager.get_session("789012")
|
||||||
|
assert session is not None
|
||||||
|
assert session['session_id'] == "sess_xyz789"
|
||||||
|
assert session['project_name'] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_existing_session(self, manager):
|
||||||
|
"""
|
||||||
|
Test updating an existing session with new session_id.
|
||||||
|
|
||||||
|
What: Save a session, then save again with different session_id
|
||||||
|
Why: Sessions need to be updateable when Claude sessions change
|
||||||
|
"""
|
||||||
|
# Create initial session
|
||||||
|
await manager.save_session(
|
||||||
|
channel_id="123456",
|
||||||
|
session_id="sess_old",
|
||||||
|
project_name="project-a"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update with new session_id
|
||||||
|
await manager.save_session(
|
||||||
|
channel_id="123456",
|
||||||
|
session_id="sess_new",
|
||||||
|
project_name="project-b"
|
||||||
|
)
|
||||||
|
|
||||||
|
session = await manager.get_session("123456")
|
||||||
|
assert session['session_id'] == "sess_new"
|
||||||
|
assert session['project_name'] == "project-b"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_nonexistent_session(self, manager):
|
||||||
|
"""
|
||||||
|
Test retrieving a session that doesn't exist.
|
||||||
|
|
||||||
|
What: Call get_session for channel with no session
|
||||||
|
Why: Should return None gracefully, not error
|
||||||
|
"""
|
||||||
|
session = await manager.get_session("nonexistent")
|
||||||
|
assert session is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_session(self, manager):
|
||||||
|
"""
|
||||||
|
Test deleting a session.
|
||||||
|
|
||||||
|
What: Create session, reset it, verify it's gone
|
||||||
|
Why: /reset command needs to clear sessions
|
||||||
|
"""
|
||||||
|
await manager.save_session(
|
||||||
|
channel_id="123456",
|
||||||
|
session_id="sess_abc123"
|
||||||
|
)
|
||||||
|
|
||||||
|
deleted = await manager.reset_session("123456")
|
||||||
|
assert deleted is True
|
||||||
|
|
||||||
|
session = await manager.get_session("123456")
|
||||||
|
assert session is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_nonexistent_session(self, manager):
|
||||||
|
"""
|
||||||
|
Test resetting a session that doesn't exist.
|
||||||
|
|
||||||
|
What: Call reset_session for channel with no session
|
||||||
|
Why: Should return False, not error
|
||||||
|
"""
|
||||||
|
deleted = await manager.reset_session("nonexistent")
|
||||||
|
assert deleted is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionActivity:
|
||||||
|
"""Test session activity tracking."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_activity(self, manager):
|
||||||
|
"""
|
||||||
|
Test updating last_active and incrementing message_count.
|
||||||
|
|
||||||
|
What: Create session, update activity multiple times, verify counters
|
||||||
|
Why: Activity tracking is essential for monitoring
|
||||||
|
"""
|
||||||
|
await manager.save_session(
|
||||||
|
channel_id="123456",
|
||||||
|
session_id="sess_abc123"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update activity 3 times
|
||||||
|
await manager.update_activity("123456")
|
||||||
|
await manager.update_activity("123456")
|
||||||
|
await manager.update_activity("123456")
|
||||||
|
|
||||||
|
session = await manager.get_session("123456")
|
||||||
|
assert session['message_count'] == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_activity_updates_timestamp(self, manager):
|
||||||
|
"""
|
||||||
|
Test that update_activity changes last_active timestamp.
|
||||||
|
|
||||||
|
What: Create session, wait, update activity, check timestamp changed
|
||||||
|
Why: Timestamp tracking is needed for session management
|
||||||
|
"""
|
||||||
|
await manager.save_session(
|
||||||
|
channel_id="123456",
|
||||||
|
session_id="sess_abc123"
|
||||||
|
)
|
||||||
|
|
||||||
|
session1 = await manager.get_session("123456")
|
||||||
|
original_timestamp = session1['last_active']
|
||||||
|
|
||||||
|
# Small delay to ensure timestamp difference
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
await manager.update_activity("123456")
|
||||||
|
|
||||||
|
session2 = await manager.get_session("123456")
|
||||||
|
new_timestamp = session2['last_active']
|
||||||
|
|
||||||
|
assert new_timestamp >= original_timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionListing:
|
||||||
|
"""Test listing and querying multiple sessions."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_empty_sessions(self, manager):
|
||||||
|
"""
|
||||||
|
Test listing when no sessions exist.
|
||||||
|
|
||||||
|
What: Call list_sessions on empty database
|
||||||
|
Why: Should return empty list, not error
|
||||||
|
"""
|
||||||
|
sessions = await manager.list_sessions()
|
||||||
|
assert sessions == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_multiple_sessions(self, manager):
|
||||||
|
"""
|
||||||
|
Test listing multiple sessions.
|
||||||
|
|
||||||
|
What: Create 3 sessions, list them, verify order
|
||||||
|
Why: Bot needs to view all active sessions
|
||||||
|
"""
|
||||||
|
await manager.save_session("channel1", "sess1", "project-a")
|
||||||
|
await asyncio.sleep(0.01) # Small delay for timestamp ordering
|
||||||
|
await manager.save_session("channel2", "sess2", "project-b")
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
await manager.save_session("channel3", "sess3", "project-c")
|
||||||
|
|
||||||
|
sessions = await manager.list_sessions()
|
||||||
|
assert len(sessions) == 3
|
||||||
|
|
||||||
|
# Should be ordered by last_active DESC (most recent first)
|
||||||
|
assert sessions[0]['channel_id'] == "channel3"
|
||||||
|
assert sessions[1]['channel_id'] == "channel2"
|
||||||
|
assert sessions[2]['channel_id'] == "channel1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_sessions_after_activity(self, manager):
|
||||||
|
"""
|
||||||
|
Test that list_sessions ordering changes after activity updates.
|
||||||
|
|
||||||
|
What: Create sessions, update activity on older one, verify reordering
|
||||||
|
Why: Most active sessions should appear first
|
||||||
|
"""
|
||||||
|
await manager.save_session("channel1", "sess1")
|
||||||
|
await asyncio.sleep(1.1) # SQLite CURRENT_TIMESTAMP has 1-second precision
|
||||||
|
await manager.save_session("channel2", "sess2")
|
||||||
|
|
||||||
|
# Update activity on channel1 (older session)
|
||||||
|
await asyncio.sleep(1.1) # SQLite CURRENT_TIMESTAMP has 1-second precision
|
||||||
|
await manager.update_activity("channel1")
|
||||||
|
|
||||||
|
sessions = await manager.list_sessions()
|
||||||
|
|
||||||
|
# channel1 should now be first (most recent activity)
|
||||||
|
assert sessions[0]['channel_id'] == "channel1"
|
||||||
|
assert sessions[1]['channel_id'] == "channel2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionStats:
|
||||||
|
"""Test session statistics and analytics."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stats_empty_database(self, manager):
|
||||||
|
"""
|
||||||
|
Test statistics on empty database.
|
||||||
|
|
||||||
|
What: Get stats when no sessions exist
|
||||||
|
Why: Should return zeros/nulls, not error
|
||||||
|
"""
|
||||||
|
stats = await manager.get_stats()
|
||||||
|
|
||||||
|
assert stats['total_sessions'] == 0
|
||||||
|
assert stats['total_messages'] == 0
|
||||||
|
assert stats['active_projects'] == 0
|
||||||
|
assert stats['most_active_channel'] is None
|
||||||
|
assert stats['oldest_session'] is None
|
||||||
|
assert stats['newest_session'] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stats_with_sessions(self, manager):
|
||||||
|
"""
|
||||||
|
Test statistics calculation with multiple sessions.
|
||||||
|
|
||||||
|
What: Create sessions with varying activity, check stats
|
||||||
|
Why: Stats are used for monitoring and debugging
|
||||||
|
"""
|
||||||
|
await manager.save_session("channel1", "sess1", "project-a")
|
||||||
|
await manager.save_session("channel2", "sess2", "project-b")
|
||||||
|
await manager.save_session("channel3", "sess3", "project-a") # Same project
|
||||||
|
await manager.save_session("channel4", "sess4") # No project
|
||||||
|
|
||||||
|
# Add some message activity
|
||||||
|
for _ in range(5):
|
||||||
|
await manager.update_activity("channel1")
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
await manager.update_activity("channel2")
|
||||||
|
|
||||||
|
stats = await manager.get_stats()
|
||||||
|
|
||||||
|
assert stats['total_sessions'] == 4
|
||||||
|
assert stats['total_messages'] == 8 # 5 + 3 + 0 + 0
|
||||||
|
assert stats['active_projects'] == 2 # project-a and project-b (not counting None)
|
||||||
|
assert stats['most_active_channel'] == "channel1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stats_oldest_newest(self, manager):
|
||||||
|
"""
|
||||||
|
Test oldest and newest session tracking.
|
||||||
|
|
||||||
|
What: Create sessions in sequence, verify oldest/newest detection
|
||||||
|
Why: Useful for session lifecycle management
|
||||||
|
"""
|
||||||
|
await manager.save_session("channel1", "sess1")
|
||||||
|
await asyncio.sleep(1.1) # SQLite CURRENT_TIMESTAMP has 1-second precision
|
||||||
|
await manager.save_session("channel2", "sess2")
|
||||||
|
await asyncio.sleep(1.1) # SQLite CURRENT_TIMESTAMP has 1-second precision
|
||||||
|
await manager.save_session("channel3", "sess3")
|
||||||
|
|
||||||
|
stats = await manager.get_stats()
|
||||||
|
|
||||||
|
assert stats['oldest_session'] == "channel1"
|
||||||
|
assert stats['newest_session'] == "channel3"
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrentAccess:
|
||||||
|
"""Test concurrent access and thread safety."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_reads(self, manager):
|
||||||
|
"""
|
||||||
|
Test multiple concurrent read operations.
|
||||||
|
|
||||||
|
What: Create session, then read it from multiple coroutines simultaneously
|
||||||
|
Why: Bot handles multiple channels concurrently
|
||||||
|
"""
|
||||||
|
await manager.save_session("channel1", "sess1", "project-a")
|
||||||
|
|
||||||
|
# Launch 10 concurrent reads
|
||||||
|
tasks = [manager.get_session("channel1") for _ in range(10)]
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# All reads should succeed
|
||||||
|
assert all(r is not None for r in results)
|
||||||
|
assert all(r['session_id'] == "sess1" for r in results)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_writes(self, manager):
|
||||||
|
"""
|
||||||
|
Test multiple concurrent write operations.
|
||||||
|
|
||||||
|
What: Create multiple sessions concurrently
|
||||||
|
Why: Multiple channels may save sessions simultaneously
|
||||||
|
"""
|
||||||
|
tasks = [
|
||||||
|
manager.save_session(f"channel{i}", f"sess{i}", "project-a")
|
||||||
|
for i in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Verify all sessions were created
|
||||||
|
sessions = await manager.list_sessions()
|
||||||
|
assert len(sessions) == 10
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_updates_same_channel(self, manager):
|
||||||
|
"""
|
||||||
|
Test concurrent activity updates on the same channel.
|
||||||
|
|
||||||
|
What: Update activity on one channel from multiple coroutines
|
||||||
|
Why: Ensures atomicity and prevents race conditions
|
||||||
|
"""
|
||||||
|
await manager.save_session("channel1", "sess1")
|
||||||
|
|
||||||
|
# Update activity 20 times concurrently
|
||||||
|
tasks = [manager.update_activity("channel1") for _ in range(20)]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
session = await manager.get_session("channel1")
|
||||||
|
assert session['message_count'] == 20
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_mixed_operations(self, manager):
|
||||||
|
"""
|
||||||
|
Test mixed concurrent operations (reads, writes, updates).
|
||||||
|
|
||||||
|
What: Perform saves, gets, updates, and lists concurrently
|
||||||
|
Why: Real-world usage has mixed concurrent operations
|
||||||
|
"""
|
||||||
|
# Create initial session
|
||||||
|
await manager.save_session("channel1", "sess1")
|
||||||
|
|
||||||
|
# Mix of operations
|
||||||
|
tasks = [
|
||||||
|
manager.get_session("channel1"),
|
||||||
|
manager.update_activity("channel1"),
|
||||||
|
manager.save_session("channel2", "sess2"),
|
||||||
|
manager.list_sessions(),
|
||||||
|
manager.get_session("channel2"),
|
||||||
|
manager.update_activity("channel1"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# All should complete without error
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Verify final state
|
||||||
|
session1 = await manager.get_session("channel1")
|
||||||
|
session2 = await manager.get_session("channel2")
|
||||||
|
|
||||||
|
assert session1 is not None
|
||||||
|
assert session2 is not None
|
||||||
|
assert session1['message_count'] == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextManager:
|
||||||
|
"""Test async context manager usage."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_manager_cleanup(self, temp_db):
|
||||||
|
"""
|
||||||
|
Test that context manager properly closes database connection.
|
||||||
|
|
||||||
|
What: Use SessionManager with async context manager, verify cleanup
|
||||||
|
Why: Proper resource cleanup prevents database lock issues
|
||||||
|
"""
|
||||||
|
manager = SessionManager(db_path=temp_db)
|
||||||
|
|
||||||
|
async with manager:
|
||||||
|
await manager.save_session("channel1", "sess1")
|
||||||
|
assert manager._db is not None
|
||||||
|
|
||||||
|
# Connection should be closed after exiting context
|
||||||
|
assert manager._db is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_manager_exception_handling(self, temp_db):
|
||||||
|
"""
|
||||||
|
Test that context manager closes connection even on exception.
|
||||||
|
|
||||||
|
What: Raise exception inside context, verify cleanup still happens
|
||||||
|
Why: Resources must be cleaned up even on errors
|
||||||
|
"""
|
||||||
|
manager = SessionManager(db_path=temp_db)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with manager:
|
||||||
|
await manager.save_session("channel1", "sess1")
|
||||||
|
raise ValueError("Test exception")
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Connection should still be closed
|
||||||
|
assert manager._db is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
"""Test edge cases and error conditions."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_channel_id(self, manager):
|
||||||
|
"""
|
||||||
|
Test handling of empty channel_id.
|
||||||
|
|
||||||
|
What: Save/get session with empty string channel_id
|
||||||
|
Why: Should handle edge case gracefully
|
||||||
|
"""
|
||||||
|
await manager.save_session("", "sess1", "project-a")
|
||||||
|
session = await manager.get_session("")
|
||||||
|
assert session is not None
|
||||||
|
assert session['channel_id'] == ""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_very_long_ids(self, manager):
|
||||||
|
"""
|
||||||
|
Test handling of very long channel and session IDs.
|
||||||
|
|
||||||
|
What: Save session with 1000-character IDs
|
||||||
|
Why: Ensure no unexpected length limits
|
||||||
|
"""
|
||||||
|
long_channel = "c" * 1000
|
||||||
|
long_session = "s" * 1000
|
||||||
|
|
||||||
|
await manager.save_session(long_channel, long_session)
|
||||||
|
session = await manager.get_session(long_channel)
|
||||||
|
|
||||||
|
assert session is not None
|
||||||
|
assert session['channel_id'] == long_channel
|
||||||
|
assert session['session_id'] == long_session
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_special_characters_in_ids(self, manager):
|
||||||
|
"""
|
||||||
|
Test handling of special characters in IDs.
|
||||||
|
|
||||||
|
What: Use IDs with quotes, newlines, unicode
|
||||||
|
Why: Ensure proper SQL escaping
|
||||||
|
"""
|
||||||
|
special_channel = "channel'with\"quotes\nand\ttabs"
|
||||||
|
special_session = "sess🎉with📊unicode"
|
||||||
|
|
||||||
|
await manager.save_session(special_channel, special_session)
|
||||||
|
session = await manager.get_session(special_channel)
|
||||||
|
|
||||||
|
assert session is not None
|
||||||
|
assert session['channel_id'] == special_channel
|
||||||
|
assert session['session_id'] == special_session
|
||||||
|
|
||||||
|
|
||||||
|
# Performance test (optional, can be slow)
|
||||||
|
class TestPerformance:
|
||||||
|
"""Test performance characteristics (can be skipped for quick tests)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.slow
|
||||||
|
async def test_large_number_of_sessions(self, manager):
|
||||||
|
"""
|
||||||
|
Test handling of large number of sessions.
|
||||||
|
|
||||||
|
What: Create 1000 sessions and verify operations remain fast
|
||||||
|
Why: Bot may accumulate many sessions over time
|
||||||
|
"""
|
||||||
|
# Create 1000 sessions
|
||||||
|
for i in range(1000):
|
||||||
|
await manager.save_session(f"channel{i}", f"sess{i}", f"project{i % 10}")
|
||||||
|
|
||||||
|
# List should still be reasonably fast
|
||||||
|
sessions = await manager.list_sessions()
|
||||||
|
assert len(sessions) == 1000
|
||||||
|
|
||||||
|
# Stats should work
|
||||||
|
stats = await manager.get_stats()
|
||||||
|
assert stats['total_sessions'] == 1000
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run tests with pytest
|
||||||
|
import pytest
|
||||||
|
pytest.main([__file__, "-v", "-s"])
|
||||||
Loading…
Reference in New Issue
Block a user