Implement comprehensive rate limiting system and item spawn configuration
Major Features Added: - Complete token bucket rate limiting for IRC commands and web interface - Per-user rate tracking with category-based limits (Basic, Gameplay, Management, Admin, Web) - Admin commands for rate limit management (\!rate_stats, \!rate_user, \!rate_unban, \!rate_reset) - Automatic violation tracking and temporary bans with cleanup - Global item spawn multiplier system with 75% spawn rate reduction - Central admin configuration system (config.py) - One-command bot startup script (start_petbot.sh) Rate Limiting: - Token bucket algorithm with burst capacity and refill rates - Category limits: Basic (20/min), Gameplay (10/min), Management (5/min), Web (60/min) - Graceful violation handling with user-friendly error messages - Admin exemption and override capabilities - Background cleanup of old violations and expired bans Item Spawn System: - Added global_spawn_multiplier to config/items.json for easy adjustment - Reduced all individual spawn rates by 75% (multiplied by 0.25) - Admins can fine-tune both global multiplier and individual item rates - Game engine integration applies multiplier to all spawn calculations Infrastructure: - Single admin user configuration in config.py - Enhanced startup script with dependency management and verification - Updated documentation and help system with rate limiting guide - Comprehensive test suite for rate limiting functionality Security: - Rate limiting protects against command spam and abuse - IP-based tracking for web interface requests - Proper error handling and status codes (429 for rate limits) 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
f8ac661cd1
commit
915aa00bea
28 changed files with 5730 additions and 57 deletions
458
src/backup_manager.py
Normal file
458
src/backup_manager.py
Normal file
|
|
@ -0,0 +1,458 @@
|
|||
import os
|
||||
import shutil
|
||||
import sqlite3
|
||||
import gzip
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
import asyncio
|
||||
import aiosqlite
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
|
||||
class BackupManager:
|
||||
def __init__(self, db_path: str = "data/petbot.db", backup_dir: str = "backups"):
|
||||
self.db_path = db_path
|
||||
self.backup_dir = Path(backup_dir)
|
||||
self.backup_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Backup configuration
|
||||
self.max_daily_backups = 7 # Keep 7 daily backups
|
||||
self.max_weekly_backups = 4 # Keep 4 weekly backups
|
||||
self.max_monthly_backups = 12 # Keep 12 monthly backups
|
||||
|
||||
# Setup logging
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def create_backup(self, backup_type: str = "manual", compress: bool = True) -> Dict:
|
||||
"""Create a database backup with optional compression."""
|
||||
try:
|
||||
# Check if database exists
|
||||
if not os.path.exists(self.db_path):
|
||||
return {"success": False, "error": "Database file not found"}
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_filename = f"petbot_backup_{backup_type}_{timestamp}.db"
|
||||
|
||||
if compress:
|
||||
backup_filename += ".gz"
|
||||
|
||||
backup_path = self.backup_dir / backup_filename
|
||||
|
||||
# Create the backup
|
||||
if compress:
|
||||
await self._create_compressed_backup(backup_path)
|
||||
else:
|
||||
await self._create_regular_backup(backup_path)
|
||||
|
||||
# Get backup info
|
||||
backup_info = await self._get_backup_info(backup_path)
|
||||
|
||||
# Log the backup
|
||||
self.logger.info(f"Backup created: {backup_filename} ({backup_info['size_mb']:.1f}MB)")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"backup_path": str(backup_path),
|
||||
"backup_filename": backup_filename,
|
||||
"backup_type": backup_type,
|
||||
"timestamp": timestamp,
|
||||
"compressed": compress,
|
||||
"size_mb": backup_info["size_mb"]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Backup creation failed: {str(e)}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _create_regular_backup(self, backup_path: Path):
|
||||
"""Create a regular SQLite backup using the backup API."""
|
||||
def backup_db():
|
||||
# Use SQLite backup API for consistent backup
|
||||
source_conn = sqlite3.connect(self.db_path)
|
||||
backup_conn = sqlite3.connect(str(backup_path))
|
||||
|
||||
# Perform the backup
|
||||
source_conn.backup(backup_conn)
|
||||
|
||||
# Close connections
|
||||
source_conn.close()
|
||||
backup_conn.close()
|
||||
|
||||
# Run the backup in a thread to avoid blocking
|
||||
await asyncio.get_event_loop().run_in_executor(None, backup_db)
|
||||
|
||||
async def _create_compressed_backup(self, backup_path: Path):
|
||||
"""Create a compressed backup."""
|
||||
def backup_and_compress():
|
||||
# First create temporary uncompressed backup
|
||||
temp_backup = backup_path.with_suffix('.tmp')
|
||||
|
||||
# Use SQLite backup API
|
||||
source_conn = sqlite3.connect(self.db_path)
|
||||
backup_conn = sqlite3.connect(str(temp_backup))
|
||||
source_conn.backup(backup_conn)
|
||||
source_conn.close()
|
||||
backup_conn.close()
|
||||
|
||||
# Compress the backup
|
||||
with open(temp_backup, 'rb') as f_in:
|
||||
with gzip.open(backup_path, 'wb') as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
|
||||
# Remove temporary file
|
||||
temp_backup.unlink()
|
||||
|
||||
await asyncio.get_event_loop().run_in_executor(None, backup_and_compress)
|
||||
|
||||
async def _get_backup_info(self, backup_path: Path) -> Dict:
|
||||
"""Get information about a backup file."""
|
||||
stat = backup_path.stat()
|
||||
|
||||
return {
|
||||
"size_bytes": stat.st_size,
|
||||
"size_mb": stat.st_size / (1024 * 1024),
|
||||
"created_at": datetime.fromtimestamp(stat.st_mtime),
|
||||
"compressed": backup_path.suffix == '.gz'
|
||||
}
|
||||
|
||||
async def list_backups(self) -> List[Dict]:
|
||||
"""List all available backups with metadata."""
|
||||
backups = []
|
||||
|
||||
for backup_file in self.backup_dir.glob("petbot_backup_*.db*"):
|
||||
try:
|
||||
info = await self._get_backup_info(backup_file)
|
||||
|
||||
# Parse backup filename for metadata
|
||||
filename = backup_file.name
|
||||
parts = filename.replace('.gz', '').replace('.db', '').split('_')
|
||||
|
||||
if len(parts) >= 4:
|
||||
backup_type = parts[2]
|
||||
timestamp = parts[3]
|
||||
|
||||
backups.append({
|
||||
"filename": filename,
|
||||
"path": str(backup_file),
|
||||
"type": backup_type,
|
||||
"timestamp": timestamp,
|
||||
"created_at": info["created_at"],
|
||||
"size_mb": info["size_mb"],
|
||||
"compressed": info["compressed"]
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error reading backup {backup_file}: {e}")
|
||||
continue
|
||||
|
||||
# Sort by creation time (newest first)
|
||||
backups.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
return backups
|
||||
|
||||
async def restore_backup(self, backup_filename: str, target_path: str = None) -> Dict:
|
||||
"""Restore a database from backup."""
|
||||
try:
|
||||
backup_path = self.backup_dir / backup_filename
|
||||
|
||||
if not backup_path.exists():
|
||||
return {"success": False, "error": "Backup file not found"}
|
||||
|
||||
target_path = target_path or self.db_path
|
||||
|
||||
# Create backup of current database before restore
|
||||
current_backup = await self.create_backup("pre_restore", compress=True)
|
||||
if not current_backup["success"]:
|
||||
return {"success": False, "error": "Failed to backup current database"}
|
||||
|
||||
# Restore the backup
|
||||
if backup_path.suffix == '.gz':
|
||||
await self._restore_compressed_backup(backup_path, target_path)
|
||||
else:
|
||||
await self._restore_regular_backup(backup_path, target_path)
|
||||
|
||||
# Verify the restored database
|
||||
verification = await self._verify_database(target_path)
|
||||
if not verification["success"]:
|
||||
return {"success": False, "error": f"Restored database verification failed: {verification['error']}"}
|
||||
|
||||
self.logger.info(f"Database restored from backup: {backup_filename}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"backup_filename": backup_filename,
|
||||
"target_path": target_path,
|
||||
"current_backup": current_backup["backup_filename"],
|
||||
"tables_verified": verification["table_count"]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Restore failed: {str(e)}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _restore_regular_backup(self, backup_path: Path, target_path: str):
|
||||
"""Restore from regular backup."""
|
||||
def restore():
|
||||
shutil.copy2(backup_path, target_path)
|
||||
|
||||
await asyncio.get_event_loop().run_in_executor(None, restore)
|
||||
|
||||
async def _restore_compressed_backup(self, backup_path: Path, target_path: str):
|
||||
"""Restore from compressed backup."""
|
||||
def restore():
|
||||
with gzip.open(backup_path, 'rb') as f_in:
|
||||
with open(target_path, 'wb') as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
|
||||
await asyncio.get_event_loop().run_in_executor(None, restore)
|
||||
|
||||
async def _verify_database(self, db_path: str) -> Dict:
|
||||
"""Verify database integrity and structure."""
|
||||
try:
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
# Check database integrity
|
||||
cursor = await db.execute("PRAGMA integrity_check")
|
||||
integrity_result = await cursor.fetchone()
|
||||
|
||||
if integrity_result[0] != "ok":
|
||||
return {"success": False, "error": f"Database integrity check failed: {integrity_result[0]}"}
|
||||
|
||||
# Count tables
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM sqlite_master WHERE type='table'")
|
||||
table_count = (await cursor.fetchone())[0]
|
||||
|
||||
# Basic table existence check
|
||||
required_tables = ["players", "pets", "pet_species", "moves", "items"]
|
||||
for table in required_tables:
|
||||
cursor = await db.execute(f"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", (table,))
|
||||
exists = (await cursor.fetchone())[0]
|
||||
if not exists:
|
||||
return {"success": False, "error": f"Required table '{table}' not found"}
|
||||
|
||||
return {"success": True, "table_count": table_count}
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def cleanup_old_backups(self) -> Dict:
|
||||
"""Remove old backups based on retention policy."""
|
||||
try:
|
||||
backups = await self.list_backups()
|
||||
|
||||
# Group backups by type
|
||||
daily_backups = [b for b in backups if b["type"] in ["daily", "manual"]]
|
||||
weekly_backups = [b for b in backups if b["type"] == "weekly"]
|
||||
monthly_backups = [b for b in backups if b["type"] == "monthly"]
|
||||
|
||||
cleaned_count = 0
|
||||
|
||||
# Clean daily backups (keep most recent)
|
||||
if len(daily_backups) > self.max_daily_backups:
|
||||
old_daily = daily_backups[self.max_daily_backups:]
|
||||
for backup in old_daily:
|
||||
await self._remove_backup(backup["path"])
|
||||
cleaned_count += 1
|
||||
|
||||
# Clean weekly backups
|
||||
if len(weekly_backups) > self.max_weekly_backups:
|
||||
old_weekly = weekly_backups[self.max_weekly_backups:]
|
||||
for backup in old_weekly:
|
||||
await self._remove_backup(backup["path"])
|
||||
cleaned_count += 1
|
||||
|
||||
# Clean monthly backups
|
||||
if len(monthly_backups) > self.max_monthly_backups:
|
||||
old_monthly = monthly_backups[self.max_monthly_backups:]
|
||||
for backup in old_monthly:
|
||||
await self._remove_backup(backup["path"])
|
||||
cleaned_count += 1
|
||||
|
||||
self.logger.info(f"Cleaned up {cleaned_count} old backups")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"cleaned_count": cleaned_count,
|
||||
"remaining_backups": len(backups) - cleaned_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Cleanup failed: {str(e)}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _remove_backup(self, backup_path: str):
|
||||
"""Remove a backup file."""
|
||||
try:
|
||||
Path(backup_path).unlink()
|
||||
self.logger.debug(f"Removed backup: {backup_path}")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to remove backup {backup_path}: {e}")
|
||||
|
||||
async def get_backup_stats(self) -> Dict:
|
||||
"""Get statistics about backups."""
|
||||
try:
|
||||
backups = await self.list_backups()
|
||||
|
||||
if not backups:
|
||||
return {
|
||||
"success": True,
|
||||
"total_backups": 0,
|
||||
"total_size_mb": 0,
|
||||
"oldest_backup": None,
|
||||
"newest_backup": None,
|
||||
"by_type": {}
|
||||
}
|
||||
|
||||
total_size = sum(b["size_mb"] for b in backups)
|
||||
oldest = min(backups, key=lambda x: x["created_at"])
|
||||
newest = max(backups, key=lambda x: x["created_at"])
|
||||
|
||||
# Group by type
|
||||
by_type = {}
|
||||
for backup in backups:
|
||||
backup_type = backup["type"]
|
||||
if backup_type not in by_type:
|
||||
by_type[backup_type] = {"count": 0, "size_mb": 0}
|
||||
by_type[backup_type]["count"] += 1
|
||||
by_type[backup_type]["size_mb"] += backup["size_mb"]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"total_backups": len(backups),
|
||||
"total_size_mb": round(total_size, 1),
|
||||
"oldest_backup": oldest["created_at"],
|
||||
"newest_backup": newest["created_at"],
|
||||
"by_type": by_type
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get backup stats: {str(e)}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def export_database_structure(self) -> Dict:
|
||||
"""Export database schema for documentation/analysis."""
|
||||
try:
|
||||
structure = {
|
||||
"export_time": datetime.now().isoformat(),
|
||||
"database_path": self.db_path,
|
||||
"tables": {}
|
||||
}
|
||||
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Get all tables
|
||||
cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
tables = await cursor.fetchall()
|
||||
|
||||
for table_name in [t[0] for t in tables]:
|
||||
# Get table info
|
||||
cursor = await db.execute(f"PRAGMA table_info({table_name})")
|
||||
columns = await cursor.fetchall()
|
||||
|
||||
# Get row count
|
||||
cursor = await db.execute(f"SELECT COUNT(*) FROM {table_name}")
|
||||
row_count = (await cursor.fetchone())[0]
|
||||
|
||||
structure["tables"][table_name] = {
|
||||
"columns": [
|
||||
{
|
||||
"name": col[1],
|
||||
"type": col[2],
|
||||
"not_null": bool(col[3]),
|
||||
"default": col[4],
|
||||
"primary_key": bool(col[5])
|
||||
}
|
||||
for col in columns
|
||||
],
|
||||
"row_count": row_count
|
||||
}
|
||||
|
||||
# Save structure to file
|
||||
structure_path = self.backup_dir / f"database_structure_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
with open(structure_path, 'w') as f:
|
||||
json.dump(structure, f, indent=2, default=str)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"structure_path": str(structure_path),
|
||||
"table_count": len(structure["tables"]),
|
||||
"total_rows": sum(table["row_count"] for table in structure["tables"].values())
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to export database structure: {str(e)}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
# Scheduler for automated backups
|
||||
class BackupScheduler:
|
||||
def __init__(self, backup_manager: BackupManager):
|
||||
self.backup_manager = backup_manager
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.running = False
|
||||
|
||||
async def start_scheduler(self):
|
||||
"""Start the backup scheduler."""
|
||||
self.running = True
|
||||
self.logger.info("Backup scheduler started")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
await self._check_and_create_backups()
|
||||
await asyncio.sleep(3600) # Check every hour
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Scheduler error: {str(e)}")
|
||||
await asyncio.sleep(3600) # Wait before retrying
|
||||
|
||||
def stop_scheduler(self):
|
||||
"""Stop the backup scheduler."""
|
||||
self.running = False
|
||||
self.logger.info("Backup scheduler stopped")
|
||||
|
||||
async def _check_and_create_backups(self):
|
||||
"""Check if backups are needed and create them."""
|
||||
now = datetime.now()
|
||||
|
||||
# Check daily backup (every 24 hours)
|
||||
if await self._should_create_backup("daily", hours=24):
|
||||
result = await self.backup_manager.create_backup("daily", compress=True)
|
||||
if result["success"]:
|
||||
self.logger.info(f"Daily backup created: {result['backup_filename']}")
|
||||
|
||||
# Check weekly backup (every 7 days)
|
||||
if await self._should_create_backup("weekly", days=7):
|
||||
result = await self.backup_manager.create_backup("weekly", compress=True)
|
||||
if result["success"]:
|
||||
self.logger.info(f"Weekly backup created: {result['backup_filename']}")
|
||||
|
||||
# Check monthly backup (every 30 days)
|
||||
if await self._should_create_backup("monthly", days=30):
|
||||
result = await self.backup_manager.create_backup("monthly", compress=True)
|
||||
if result["success"]:
|
||||
self.logger.info(f"Monthly backup created: {result['backup_filename']}")
|
||||
|
||||
# Cleanup old backups
|
||||
await self.backup_manager.cleanup_old_backups()
|
||||
|
||||
async def _should_create_backup(self, backup_type: str, hours: int = 0, days: int = 0) -> bool:
|
||||
"""Check if a backup of the specified type should be created."""
|
||||
try:
|
||||
backups = await self.backup_manager.list_backups()
|
||||
|
||||
# Find most recent backup of this type
|
||||
type_backups = [b for b in backups if b["type"] == backup_type]
|
||||
|
||||
if not type_backups:
|
||||
return True # No backups of this type exist
|
||||
|
||||
most_recent = max(type_backups, key=lambda x: x["created_at"])
|
||||
time_since = datetime.now() - most_recent["created_at"]
|
||||
|
||||
required_delta = timedelta(hours=hours, days=days)
|
||||
|
||||
return time_since >= required_delta
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error checking backup schedule: {str(e)}")
|
||||
return False
|
||||
|
|
@ -365,6 +365,9 @@ class GameEngine:
|
|||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
# Get global spawn multiplier from config
|
||||
global_multiplier = items_data.get("_config", {}).get("global_spawn_multiplier", 1.0)
|
||||
|
||||
# Get all possible items for this location
|
||||
available_items = []
|
||||
location_name = location["name"].lower().replace(" ", "_")
|
||||
|
|
@ -375,22 +378,25 @@ class GameEngine:
|
|||
if "locations" in item:
|
||||
item_locations = item["locations"]
|
||||
if "all" in item_locations or location_name in item_locations:
|
||||
available_items.append(item)
|
||||
# Apply global multiplier to spawn rate
|
||||
item_copy = item.copy()
|
||||
item_copy["effective_spawn_rate"] = item.get("spawn_rate", 0.1) * global_multiplier
|
||||
available_items.append(item_copy)
|
||||
|
||||
if not available_items:
|
||||
return None
|
||||
|
||||
# Calculate total spawn rates for this location
|
||||
total_rate = sum(item.get("spawn_rate", 0.1) for item in available_items)
|
||||
# Calculate total spawn rates for this location (using effective rates)
|
||||
total_rate = sum(item.get("effective_spawn_rate", 0.1) for item in available_items)
|
||||
|
||||
# 30% base chance of finding an item
|
||||
if random.random() > 0.3:
|
||||
return None
|
||||
|
||||
# Choose item based on spawn rates
|
||||
# Choose item based on effective spawn rates (with global multiplier applied)
|
||||
chosen_item = random.choices(
|
||||
available_items,
|
||||
weights=[item.get("spawn_rate", 0.1) for item in available_items]
|
||||
weights=[item.get("effective_spawn_rate", 0.1) for item in available_items]
|
||||
)[0]
|
||||
|
||||
# Add item to player's inventory
|
||||
|
|
|
|||
395
src/irc_connection_manager.py
Normal file
395
src/irc_connection_manager.py
Normal file
|
|
@ -0,0 +1,395 @@
|
|||
import asyncio
|
||||
import socket
|
||||
import time
|
||||
import logging
|
||||
import random
|
||||
from enum import Enum
|
||||
from typing import Optional, Callable, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class ConnectionState(Enum):
|
||||
DISCONNECTED = "disconnected"
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
AUTHENTICATED = "authenticated"
|
||||
JOINED = "joined"
|
||||
RECONNECTING = "reconnecting"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class IRCConnectionManager:
|
||||
"""
|
||||
Robust IRC connection manager with automatic reconnection,
|
||||
health monitoring, and exponential backoff.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], bot_instance=None):
|
||||
self.config = config
|
||||
self.bot = bot_instance
|
||||
self.socket = None
|
||||
self.state = ConnectionState.DISCONNECTED
|
||||
self.running = False
|
||||
|
||||
# Connection monitoring
|
||||
self.last_ping_time = 0
|
||||
self.last_pong_time = 0
|
||||
self.ping_interval = 60 # Send PING every 60 seconds
|
||||
self.ping_timeout = 120 # Expect PONG within 2 minutes
|
||||
|
||||
# Reconnection settings
|
||||
self.reconnect_attempts = 0
|
||||
self.max_reconnect_attempts = 50
|
||||
self.base_reconnect_delay = 1 # Start with 1 second
|
||||
self.max_reconnect_delay = 300 # Cap at 5 minutes
|
||||
self.reconnect_jitter = 0.1 # 10% jitter
|
||||
|
||||
# Connection tracking
|
||||
self.connection_start_time = None
|
||||
self.last_successful_connection = None
|
||||
self.total_reconnections = 0
|
||||
self.connection_failures = 0
|
||||
|
||||
# Event callbacks
|
||||
self.on_connect_callback = None
|
||||
self.on_disconnect_callback = None
|
||||
self.on_message_callback = None
|
||||
self.on_connection_lost_callback = None
|
||||
|
||||
# Health monitoring
|
||||
self.health_check_interval = 30 # Check health every 30 seconds
|
||||
self.health_check_task = None
|
||||
self.message_count = 0
|
||||
self.last_message_time = 0
|
||||
|
||||
# Setup logging
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger.setLevel(logging.INFO)
|
||||
|
||||
# Create console handler if none exists
|
||||
if not self.logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
handler.setFormatter(formatter)
|
||||
self.logger.addHandler(handler)
|
||||
|
||||
def set_callbacks(self, on_connect=None, on_disconnect=None, on_message=None, on_connection_lost=None):
|
||||
"""Set callback functions for connection events."""
|
||||
self.on_connect_callback = on_connect
|
||||
self.on_disconnect_callback = on_disconnect
|
||||
self.on_message_callback = on_message
|
||||
self.on_connection_lost_callback = on_connection_lost
|
||||
|
||||
async def start(self):
|
||||
"""Start the connection manager."""
|
||||
if self.running:
|
||||
self.logger.warning("Connection manager is already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self.logger.info("Starting IRC connection manager")
|
||||
|
||||
# Start health monitoring
|
||||
self.health_check_task = asyncio.create_task(self._health_monitor())
|
||||
|
||||
# Start connection loop
|
||||
await self._connection_loop()
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the connection manager and close connections."""
|
||||
self.running = False
|
||||
self.logger.info("Stopping IRC connection manager")
|
||||
|
||||
# Cancel health monitoring
|
||||
if self.health_check_task:
|
||||
self.health_check_task.cancel()
|
||||
try:
|
||||
await self.health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close socket
|
||||
await self._disconnect()
|
||||
|
||||
async def _connection_loop(self):
|
||||
"""Main connection loop with automatic reconnection."""
|
||||
while self.running:
|
||||
try:
|
||||
if self.state == ConnectionState.DISCONNECTED:
|
||||
await self._connect()
|
||||
|
||||
if self.state in [ConnectionState.CONNECTED, ConnectionState.AUTHENTICATED, ConnectionState.JOINED]:
|
||||
await self._handle_messages()
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in connection loop: {e}")
|
||||
await self._handle_connection_error(e)
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to IRC server with retry logic."""
|
||||
if self.reconnect_attempts >= self.max_reconnect_attempts:
|
||||
self.logger.error(f"Maximum reconnection attempts ({self.max_reconnect_attempts}) reached")
|
||||
self.state = ConnectionState.FAILED
|
||||
return
|
||||
|
||||
self.state = ConnectionState.CONNECTING
|
||||
self.connection_start_time = datetime.now()
|
||||
|
||||
try:
|
||||
# Calculate reconnection delay with exponential backoff
|
||||
if self.reconnect_attempts > 0:
|
||||
delay = min(
|
||||
self.base_reconnect_delay * (2 ** self.reconnect_attempts),
|
||||
self.max_reconnect_delay
|
||||
)
|
||||
# Add jitter to prevent thundering herd
|
||||
jitter = delay * self.reconnect_jitter * random.random()
|
||||
delay += jitter
|
||||
|
||||
self.logger.info(f"Reconnection attempt {self.reconnect_attempts + 1}/{self.max_reconnect_attempts} after {delay:.1f}s delay")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Test basic connectivity first
|
||||
await self._test_connectivity()
|
||||
|
||||
# Create socket connection
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.settimeout(10) # 10 second timeout for connection
|
||||
|
||||
# Connect to server
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.socket.connect, (self.config["server"], self.config["port"])
|
||||
)
|
||||
|
||||
self.socket.settimeout(1) # Shorter timeout for message handling
|
||||
self.state = ConnectionState.CONNECTED
|
||||
|
||||
# Send IRC handshake
|
||||
await self._send_handshake()
|
||||
|
||||
# Reset reconnection counter on successful connection
|
||||
self.reconnect_attempts = 0
|
||||
self.last_successful_connection = datetime.now()
|
||||
self.total_reconnections += 1
|
||||
|
||||
self.logger.info(f"Successfully connected to {self.config['server']}:{self.config['port']}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Connection failed: {e}")
|
||||
self.connection_failures += 1
|
||||
self.reconnect_attempts += 1
|
||||
await self._disconnect()
|
||||
|
||||
if self.on_connection_lost_callback:
|
||||
await self.on_connection_lost_callback(e)
|
||||
|
||||
async def _test_connectivity(self):
|
||||
"""Test basic network connectivity to IRC server."""
|
||||
try:
|
||||
test_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
test_sock.settimeout(5)
|
||||
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, test_sock.connect, (self.config["server"], self.config["port"])
|
||||
)
|
||||
|
||||
test_sock.close()
|
||||
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Network connectivity test failed: {e}")
|
||||
|
||||
async def _send_handshake(self):
|
||||
"""Send IRC handshake messages."""
|
||||
nickname = self.config["nickname"]
|
||||
await self._send_raw(f"NICK {nickname}")
|
||||
await self._send_raw(f"USER {nickname} 0 * :{nickname}")
|
||||
|
||||
# Initialize ping tracking
|
||||
self.last_ping_time = time.time()
|
||||
self.last_pong_time = time.time()
|
||||
|
||||
async def _handle_messages(self):
|
||||
"""Handle incoming IRC messages."""
|
||||
try:
|
||||
data = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.socket.recv, 4096
|
||||
)
|
||||
|
||||
if not data:
|
||||
raise ConnectionError("Connection closed by server")
|
||||
|
||||
# Update message tracking
|
||||
self.message_count += 1
|
||||
self.last_message_time = time.time()
|
||||
|
||||
# Decode and process messages
|
||||
lines = data.decode('utf-8', errors='ignore').strip().split('\n')
|
||||
for line in lines:
|
||||
if line.strip():
|
||||
await self._process_line(line.strip())
|
||||
|
||||
except socket.timeout:
|
||||
# Timeout is expected, continue
|
||||
pass
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Message handling error: {e}")
|
||||
|
||||
async def _process_line(self, line):
|
||||
"""Process a single IRC line."""
|
||||
# Handle PING/PONG
|
||||
if line.startswith("PING"):
|
||||
pong_response = line.replace("PING", "PONG")
|
||||
await self._send_raw(pong_response)
|
||||
return
|
||||
|
||||
if line.startswith("PONG"):
|
||||
self.last_pong_time = time.time()
|
||||
return
|
||||
|
||||
# Handle connection completion
|
||||
if "376" in line or "422" in line: # End of MOTD
|
||||
if self.state == ConnectionState.CONNECTED:
|
||||
self.state = ConnectionState.AUTHENTICATED
|
||||
await self._send_raw(f"JOIN {self.config['channel']}")
|
||||
|
||||
# Handle successful channel join
|
||||
if " JOIN " in line and self.config["channel"] in line:
|
||||
if self.state == ConnectionState.AUTHENTICATED:
|
||||
self.state = ConnectionState.JOINED
|
||||
self.logger.info(f"Successfully joined {self.config['channel']}")
|
||||
|
||||
if self.on_connect_callback:
|
||||
await self.on_connect_callback()
|
||||
|
||||
# Handle nickname conflicts
|
||||
if "433" in line: # Nickname in use
|
||||
new_nickname = f"{self.config['nickname']}_"
|
||||
self.logger.warning(f"Nickname conflict, trying {new_nickname}")
|
||||
await self._send_raw(f"NICK {new_nickname}")
|
||||
|
||||
# Handle disconnection
|
||||
if "ERROR :Closing Link" in line:
|
||||
raise ConnectionError("Server closed connection")
|
||||
|
||||
# Forward message to callback
|
||||
if self.on_message_callback:
|
||||
await self.on_message_callback(line)
|
||||
|
||||
async def _send_raw(self, message):
|
||||
"""Send raw IRC message."""
|
||||
if not self.socket:
|
||||
raise ConnectionError("Not connected to IRC server")
|
||||
|
||||
try:
|
||||
full_message = f"{message}\r\n"
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.socket.send, full_message.encode('utf-8')
|
||||
)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Failed to send message: {e}")
|
||||
|
||||
async def send_message(self, target, message):
|
||||
"""Send a message to a channel or user."""
|
||||
if self.state != ConnectionState.JOINED:
|
||||
self.logger.warning(f"Cannot send message, not joined to channel (state: {self.state})")
|
||||
return False
|
||||
|
||||
try:
|
||||
await self._send_raw(f"PRIVMSG {target} :{message}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to send message to {target}: {e}")
|
||||
return False
|
||||
|
||||
async def _health_monitor(self):
|
||||
"""Monitor connection health and send periodic pings."""
|
||||
while self.running:
|
||||
try:
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
|
||||
if self.state == ConnectionState.JOINED:
|
||||
await self._check_connection_health()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Health monitor error: {e}")
|
||||
|
||||
async def _check_connection_health(self):
|
||||
"""Check if connection is healthy and send pings as needed."""
|
||||
current_time = time.time()
|
||||
|
||||
# Send ping if interval has passed
|
||||
if current_time - self.last_ping_time > self.ping_interval:
|
||||
try:
|
||||
await self._send_raw(f"PING :health_check_{int(current_time)}")
|
||||
self.last_ping_time = current_time
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to send ping: {e}")
|
||||
raise ConnectionError("Health check ping failed")
|
||||
|
||||
# Check if we've received a pong recently
|
||||
if current_time - self.last_pong_time > self.ping_timeout:
|
||||
self.logger.warning("No PONG received within timeout period")
|
||||
raise ConnectionError("Ping timeout - connection appears dead")
|
||||
|
||||
async def _handle_connection_error(self, error):
|
||||
"""Handle connection errors and initiate reconnection."""
|
||||
self.logger.error(f"Connection error: {error}")
|
||||
|
||||
# Notify callback
|
||||
if self.on_disconnect_callback:
|
||||
await self.on_disconnect_callback(error)
|
||||
|
||||
# Disconnect and prepare for reconnection
|
||||
await self._disconnect()
|
||||
|
||||
# Set state for reconnection
|
||||
if self.running:
|
||||
self.state = ConnectionState.DISCONNECTED
|
||||
self.reconnect_attempts += 1
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from IRC server."""
|
||||
if self.socket:
|
||||
try:
|
||||
self.socket.close()
|
||||
except:
|
||||
pass
|
||||
self.socket = None
|
||||
|
||||
old_state = self.state
|
||||
self.state = ConnectionState.DISCONNECTED
|
||||
|
||||
if old_state != ConnectionState.DISCONNECTED:
|
||||
self.logger.info("Disconnected from IRC server")
|
||||
|
||||
def get_connection_stats(self) -> Dict[str, Any]:
|
||||
"""Get connection statistics."""
|
||||
uptime = None
|
||||
if self.connection_start_time:
|
||||
uptime = datetime.now() - self.connection_start_time
|
||||
|
||||
return {
|
||||
"state": self.state.value,
|
||||
"connected": self.state == ConnectionState.JOINED,
|
||||
"uptime": str(uptime) if uptime else None,
|
||||
"reconnect_attempts": self.reconnect_attempts,
|
||||
"total_reconnections": self.total_reconnections,
|
||||
"connection_failures": self.connection_failures,
|
||||
"last_successful_connection": self.last_successful_connection,
|
||||
"message_count": self.message_count,
|
||||
"last_message_time": datetime.fromtimestamp(self.last_message_time) if self.last_message_time else None,
|
||||
"last_ping_time": datetime.fromtimestamp(self.last_ping_time) if self.last_ping_time else None,
|
||||
"last_pong_time": datetime.fromtimestamp(self.last_pong_time) if self.last_pong_time else None
|
||||
}
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if bot is connected and ready."""
|
||||
return self.state == ConnectionState.JOINED
|
||||
|
||||
def get_state(self) -> ConnectionState:
|
||||
"""Get current connection state."""
|
||||
return self.state
|
||||
426
src/rate_limiter.py
Normal file
426
src/rate_limiter.py
Normal file
|
|
@ -0,0 +1,426 @@
|
|||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
|
||||
class CommandCategory(Enum):
|
||||
"""Categories of commands with different rate limits."""
|
||||
BASIC = "basic" # !help, !ping, !status
|
||||
GAMEPLAY = "gameplay" # !explore, !catch, !battle
|
||||
MANAGEMENT = "management" # !pets, !activate, !deactivate
|
||||
ADMIN = "admin" # !backup, !reload, !reconnect
|
||||
WEB = "web" # Web interface requests
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Token bucket rate limiter with per-user tracking and command categories.
|
||||
|
||||
Features:
|
||||
- Token bucket algorithm for smooth rate limiting
|
||||
- Per-user rate tracking
|
||||
- Different limits for different command categories
|
||||
- Burst capacity handling
|
||||
- Admin exemption
|
||||
- Detailed logging and monitoring
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[Dict] = None):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Default rate limit configuration
|
||||
self.config = {
|
||||
"enabled": True,
|
||||
"categories": {
|
||||
CommandCategory.BASIC: {
|
||||
"requests_per_minute": 20,
|
||||
"burst_capacity": 5,
|
||||
"cooldown_seconds": 1
|
||||
},
|
||||
CommandCategory.GAMEPLAY: {
|
||||
"requests_per_minute": 10,
|
||||
"burst_capacity": 3,
|
||||
"cooldown_seconds": 3
|
||||
},
|
||||
CommandCategory.MANAGEMENT: {
|
||||
"requests_per_minute": 5,
|
||||
"burst_capacity": 2,
|
||||
"cooldown_seconds": 5
|
||||
},
|
||||
CommandCategory.ADMIN: {
|
||||
"requests_per_minute": 100,
|
||||
"burst_capacity": 10,
|
||||
"cooldown_seconds": 0
|
||||
},
|
||||
CommandCategory.WEB: {
|
||||
"requests_per_minute": 60,
|
||||
"burst_capacity": 10,
|
||||
"cooldown_seconds": 1
|
||||
}
|
||||
},
|
||||
"admin_users": ["megasconed"], # This will be overridden by bot initialization
|
||||
"global_limits": {
|
||||
"max_requests_per_minute": 200,
|
||||
"max_concurrent_users": 100
|
||||
},
|
||||
"violation_penalties": {
|
||||
"warning_threshold": 3,
|
||||
"temporary_ban_threshold": 10,
|
||||
"temporary_ban_duration": 300 # 5 minutes
|
||||
}
|
||||
}
|
||||
|
||||
# Override with provided config
|
||||
if config:
|
||||
self._update_config(config)
|
||||
|
||||
# Rate limiting state
|
||||
self.user_buckets: Dict[str, Dict] = {}
|
||||
self.global_stats = {
|
||||
"requests_this_minute": 0,
|
||||
"minute_start": time.time(),
|
||||
"active_users": set(),
|
||||
"total_requests": 0,
|
||||
"blocked_requests": 0
|
||||
}
|
||||
|
||||
# Violation tracking
|
||||
self.violations: Dict[str, Dict] = {}
|
||||
self.banned_users: Dict[str, float] = {} # user -> ban_end_time
|
||||
|
||||
# Background cleanup task
|
||||
self.cleanup_task = None
|
||||
self.start_cleanup_task()
|
||||
|
||||
def _update_config(self, config: Dict):
|
||||
"""Update configuration with provided values."""
|
||||
def deep_update(base_dict, update_dict):
|
||||
for key, value in update_dict.items():
|
||||
if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
|
||||
deep_update(base_dict[key], value)
|
||||
else:
|
||||
base_dict[key] = value
|
||||
|
||||
deep_update(self.config, config)
|
||||
|
||||
def start_cleanup_task(self):
|
||||
"""Start background cleanup task."""
|
||||
if self.cleanup_task is None or self.cleanup_task.done():
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""Background task to clean up old data."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(60) # Cleanup every minute
|
||||
await self._cleanup_old_data()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in rate limiter cleanup: {e}")
|
||||
|
||||
async def _cleanup_old_data(self):
|
||||
"""Clean up old rate limiting data."""
|
||||
current_time = time.time()
|
||||
|
||||
# Clean up old user buckets (inactive for 10 minutes)
|
||||
inactive_threshold = current_time - 600
|
||||
inactive_users = [
|
||||
user for user, data in self.user_buckets.items()
|
||||
if data.get("last_request", 0) < inactive_threshold
|
||||
]
|
||||
|
||||
for user in inactive_users:
|
||||
del self.user_buckets[user]
|
||||
|
||||
# Clean up old violations (older than 1 hour)
|
||||
violation_threshold = current_time - 3600
|
||||
old_violations = [
|
||||
user for user, data in self.violations.items()
|
||||
if data.get("last_violation", 0) < violation_threshold
|
||||
]
|
||||
|
||||
for user in old_violations:
|
||||
del self.violations[user]
|
||||
|
||||
# Clean up expired bans
|
||||
expired_bans = [
|
||||
user for user, ban_end in self.banned_users.items()
|
||||
if current_time > ban_end
|
||||
]
|
||||
|
||||
for user in expired_bans:
|
||||
del self.banned_users[user]
|
||||
self.logger.info(f"Temporary ban expired for user: {user}")
|
||||
|
||||
# Reset global stats every minute
|
||||
if current_time - self.global_stats["minute_start"] >= 60:
|
||||
self.global_stats["requests_this_minute"] = 0
|
||||
self.global_stats["minute_start"] = current_time
|
||||
self.global_stats["active_users"].clear()
|
||||
|
||||
async def check_rate_limit(self, user: str, category: CommandCategory,
|
||||
command: str = None) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Check if a request is allowed under rate limiting.
|
||||
|
||||
Returns:
|
||||
(allowed: bool, message: Optional[str])
|
||||
"""
|
||||
if not self.config["enabled"]:
|
||||
return True, None
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# Check if user is temporarily banned
|
||||
if user in self.banned_users:
|
||||
if current_time < self.banned_users[user]:
|
||||
remaining = int(self.banned_users[user] - current_time)
|
||||
return False, f"⛔ You are temporarily banned for {remaining} seconds due to rate limit violations."
|
||||
else:
|
||||
del self.banned_users[user]
|
||||
|
||||
# Admin exemption
|
||||
if user.lower() in [admin.lower() for admin in self.config["admin_users"]]:
|
||||
return True, None
|
||||
|
||||
# Check global limits
|
||||
if not await self._check_global_limits():
|
||||
return False, "🚫 Server is currently overloaded. Please try again later."
|
||||
|
||||
# Check per-user rate limit
|
||||
allowed, message = await self._check_user_rate_limit(user, category, current_time)
|
||||
|
||||
if allowed:
|
||||
# Update global stats
|
||||
self.global_stats["requests_this_minute"] += 1
|
||||
self.global_stats["total_requests"] += 1
|
||||
self.global_stats["active_users"].add(user)
|
||||
else:
|
||||
# Track violation
|
||||
await self._track_violation(user, category, command)
|
||||
self.global_stats["blocked_requests"] += 1
|
||||
|
||||
return allowed, message
|
||||
|
||||
async def _check_global_limits(self) -> bool:
|
||||
"""Check global rate limits."""
|
||||
# Check requests per minute
|
||||
if (self.global_stats["requests_this_minute"] >=
|
||||
self.config["global_limits"]["max_requests_per_minute"]):
|
||||
return False
|
||||
|
||||
# Check concurrent users
|
||||
if (len(self.global_stats["active_users"]) >=
|
||||
self.config["global_limits"]["max_concurrent_users"]):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _check_user_rate_limit(self, user: str, category: CommandCategory,
|
||||
current_time: float) -> Tuple[bool, Optional[str]]:
|
||||
"""Check per-user rate limit using token bucket algorithm."""
|
||||
category_config = self.config["categories"][category]
|
||||
|
||||
# Get or create user bucket
|
||||
if user not in self.user_buckets:
|
||||
self.user_buckets[user] = {
|
||||
"tokens": category_config["burst_capacity"],
|
||||
"last_refill": current_time,
|
||||
"last_request": current_time
|
||||
}
|
||||
|
||||
bucket = self.user_buckets[user]
|
||||
|
||||
# Calculate tokens to add (refill rate)
|
||||
time_passed = current_time - bucket["last_refill"]
|
||||
tokens_per_second = category_config["requests_per_minute"] / 60.0
|
||||
tokens_to_add = time_passed * tokens_per_second
|
||||
|
||||
# Refill bucket (up to burst capacity)
|
||||
bucket["tokens"] = min(
|
||||
category_config["burst_capacity"],
|
||||
bucket["tokens"] + tokens_to_add
|
||||
)
|
||||
bucket["last_refill"] = current_time
|
||||
|
||||
# Check if request is allowed
|
||||
if bucket["tokens"] >= 1:
|
||||
bucket["tokens"] -= 1
|
||||
bucket["last_request"] = current_time
|
||||
return True, None
|
||||
else:
|
||||
# Calculate cooldown time
|
||||
time_since_last = current_time - bucket["last_request"]
|
||||
cooldown = category_config["cooldown_seconds"]
|
||||
|
||||
if time_since_last < cooldown:
|
||||
remaining = int(cooldown - time_since_last)
|
||||
return False, f"⏱️ Rate limit exceeded. Please wait {remaining} seconds before using {category.value} commands."
|
||||
else:
|
||||
# Allow if cooldown has passed
|
||||
bucket["tokens"] = category_config["burst_capacity"] - 1
|
||||
bucket["last_request"] = current_time
|
||||
return True, None
|
||||
|
||||
async def _track_violation(self, user: str, category: CommandCategory, command: str):
|
||||
"""Track rate limit violations for potential penalties."""
|
||||
current_time = time.time()
|
||||
|
||||
if user not in self.violations:
|
||||
self.violations[user] = {
|
||||
"count": 0,
|
||||
"last_violation": current_time,
|
||||
"categories": {}
|
||||
}
|
||||
|
||||
violation = self.violations[user]
|
||||
violation["count"] += 1
|
||||
violation["last_violation"] = current_time
|
||||
|
||||
# Track by category
|
||||
if category.value not in violation["categories"]:
|
||||
violation["categories"][category.value] = 0
|
||||
violation["categories"][category.value] += 1
|
||||
|
||||
# Check for penalties
|
||||
penalty_config = self.config["violation_penalties"]
|
||||
|
||||
if violation["count"] >= penalty_config["temporary_ban_threshold"]:
|
||||
# Temporary ban
|
||||
ban_duration = penalty_config["temporary_ban_duration"]
|
||||
self.banned_users[user] = current_time + ban_duration
|
||||
self.logger.warning(f"User {user} temporarily banned for {ban_duration}s due to rate limit violations")
|
||||
elif violation["count"] >= penalty_config["warning_threshold"]:
|
||||
# Warning threshold reached
|
||||
self.logger.warning(f"User {user} reached rate limit warning threshold ({violation['count']} violations)")
|
||||
|
||||
def get_user_stats(self, user: str) -> Dict:
|
||||
"""Get rate limiting stats for a specific user."""
|
||||
stats = {
|
||||
"user": user,
|
||||
"is_banned": user in self.banned_users,
|
||||
"ban_expires": None,
|
||||
"violations": 0,
|
||||
"buckets": {},
|
||||
"admin_exemption": user.lower() in [admin.lower() for admin in self.config["admin_users"]]
|
||||
}
|
||||
|
||||
# Ban info
|
||||
if stats["is_banned"]:
|
||||
stats["ban_expires"] = datetime.fromtimestamp(self.banned_users[user])
|
||||
|
||||
# Violation info
|
||||
if user in self.violations:
|
||||
stats["violations"] = self.violations[user]["count"]
|
||||
|
||||
# Bucket info
|
||||
if user in self.user_buckets:
|
||||
bucket = self.user_buckets[user]
|
||||
stats["buckets"] = {
|
||||
"tokens": round(bucket["tokens"], 2),
|
||||
"last_request": datetime.fromtimestamp(bucket["last_request"])
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def get_global_stats(self) -> Dict:
|
||||
"""Get global rate limiting statistics."""
|
||||
return {
|
||||
"enabled": self.config["enabled"],
|
||||
"requests_this_minute": self.global_stats["requests_this_minute"],
|
||||
"active_users": len(self.global_stats["active_users"]),
|
||||
"total_requests": self.global_stats["total_requests"],
|
||||
"blocked_requests": self.global_stats["blocked_requests"],
|
||||
"banned_users": len(self.banned_users),
|
||||
"tracked_users": len(self.user_buckets),
|
||||
"total_violations": sum(v["count"] for v in self.violations.values()),
|
||||
"config": self.config
|
||||
}
|
||||
|
||||
def is_user_banned(self, user: str) -> bool:
|
||||
"""Check if a user is currently banned."""
|
||||
if user not in self.banned_users:
|
||||
return False
|
||||
|
||||
if time.time() > self.banned_users[user]:
|
||||
del self.banned_users[user]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def unban_user(self, user: str) -> bool:
|
||||
"""Manually unban a user (admin function)."""
|
||||
if user in self.banned_users:
|
||||
del self.banned_users[user]
|
||||
self.logger.info(f"User {user} manually unbanned")
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset_user_violations(self, user: str) -> bool:
|
||||
"""Reset violations for a user (admin function)."""
|
||||
if user in self.violations:
|
||||
del self.violations[user]
|
||||
self.logger.info(f"Violations reset for user {user}")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown the rate limiter and cleanup tasks."""
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await self.cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self.logger.info("Rate limiter shutdown complete")
|
||||
|
||||
|
||||
# Command category mapping
|
||||
COMMAND_CATEGORIES = {
|
||||
# Basic commands
|
||||
"help": CommandCategory.BASIC,
|
||||
"ping": CommandCategory.BASIC,
|
||||
"status": CommandCategory.BASIC,
|
||||
"uptime": CommandCategory.BASIC,
|
||||
"connection_stats": CommandCategory.BASIC,
|
||||
|
||||
# Gameplay commands
|
||||
"start": CommandCategory.GAMEPLAY,
|
||||
"explore": CommandCategory.GAMEPLAY,
|
||||
"catch": CommandCategory.GAMEPLAY,
|
||||
"battle": CommandCategory.GAMEPLAY,
|
||||
"attack": CommandCategory.GAMEPLAY,
|
||||
"moves": CommandCategory.GAMEPLAY,
|
||||
"flee": CommandCategory.GAMEPLAY,
|
||||
"travel": CommandCategory.GAMEPLAY,
|
||||
"weather": CommandCategory.GAMEPLAY,
|
||||
"gym": CommandCategory.GAMEPLAY,
|
||||
|
||||
# Management commands
|
||||
"pets": CommandCategory.MANAGEMENT,
|
||||
"activate": CommandCategory.MANAGEMENT,
|
||||
"deactivate": CommandCategory.MANAGEMENT,
|
||||
"stats": CommandCategory.MANAGEMENT,
|
||||
"inventory": CommandCategory.MANAGEMENT,
|
||||
"use": CommandCategory.MANAGEMENT,
|
||||
"nickname": CommandCategory.MANAGEMENT,
|
||||
|
||||
# Admin commands
|
||||
"backup": CommandCategory.ADMIN,
|
||||
"restore": CommandCategory.ADMIN,
|
||||
"backups": CommandCategory.ADMIN,
|
||||
"backup_stats": CommandCategory.ADMIN,
|
||||
"backup_cleanup": CommandCategory.ADMIN,
|
||||
"reload": CommandCategory.ADMIN,
|
||||
"reconnect": CommandCategory.ADMIN,
|
||||
}
|
||||
|
||||
|
||||
def get_command_category(command: str) -> CommandCategory:
|
||||
"""Get the rate limiting category for a command."""
|
||||
return COMMAND_CATEGORIES.get(command.lower(), CommandCategory.GAMEPLAY)
|
||||
Loading…
Add table
Add a link
Reference in a new issue