Implement comprehensive rate limiting system and item spawn configuration

Major Features Added:
- Complete token bucket rate limiting for IRC commands and web interface
- Per-user rate tracking with category-based limits (Basic, Gameplay, Management, Admin, Web)
- Admin commands for rate limit management (!rate_stats, !rate_user, !rate_unban, !rate_reset)
- Automatic violation tracking and temporary bans with cleanup
- Global item spawn multiplier system with 75% spawn rate reduction
- Central admin configuration system (config.py)
- One-command bot startup script (start_petbot.sh)

Rate Limiting:
- Token bucket algorithm with burst capacity and refill rates (see the sketch after this list)
- Category limits: Basic (20/min), Gameplay (10/min), Management (5/min), Web (60/min)
- Graceful violation handling with user-friendly error messages
- Admin exemption and override capabilities
- Background cleanup of old violations and expired bans
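
For reference, a minimal sketch of the refill-and-spend step at the core of src/rate_limiter.py (simplified; the real implementation below also tracks cooldowns, violations, and temporary bans):

import time

def try_spend_token(bucket: dict, requests_per_minute: float, burst_capacity: float) -> bool:
    """Refill the bucket from elapsed time, then try to spend one token."""
    now = time.time()
    refill = (now - bucket.get("last_refill", now)) * (requests_per_minute / 60.0)
    bucket["tokens"] = min(burst_capacity, bucket.get("tokens", burst_capacity) + refill)
    bucket["last_refill"] = now
    if bucket["tokens"] >= 1:
        bucket["tokens"] -= 1
        return True
    return False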

Item Spawn System:
- Added global_spawn_multiplier to config/items.json for easy adjustment
- Reduced all individual spawn rates by 75% (multiplied by 0.25)
- Admins can fine-tune both global multiplier and individual item rates
- Game engine integration applies the multiplier to all spawn calculations (see the sketch after this list)
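
Concretely, config/items.json gains a top-level "_config" entry whose "global_spawn_multiplier" scales every item's spawn_rate at selection time (see the game engine diff below). A minimal sketch of that weighting, with illustrative numbers:

import random

def pick_item(available_items: list, global_multiplier: float) -> dict:
    """Weight each candidate item by spawn_rate * global_spawn_multiplier."""
    weights = [item.get("spawn_rate", 0.1) * global_multiplier for item in available_items]
    return random.choices(available_items, weights=weights)[0]

# e.g. an old spawn_rate of 0.4 behaves like 0.1 after the 75% reduction
# (0.4 * 0.25 = 0.1).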

Infrastructure:
- Single admin user configuration in config.py
- Enhanced startup script with dependency management and verification
- Updated documentation and help system with rate limiting guide
- Comprehensive test suite for rate limiting functionality

Security:
- Rate limiting protects against command spam and abuse
- IP-based tracking for web interface requests
- Proper error handling and status codes (429 for rate limits)
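
The web-interface code itself is not part of this excerpt; as a hedged illustration of the 429 behaviour, an aiohttp-style middleware built on the same RateLimiter might look like this (the middleware shape and names are assumptions, not the project's actual handler):

from aiohttp import web

from src.rate_limiter import RateLimiter, CommandCategory  # module path per this commit's layout (assumed importable)

def make_rate_limit_middleware(limiter: RateLimiter):
    @web.middleware
    async def rate_limit_middleware(request: web.Request, handler):
        # Track web requests by client IP.
        client_ip = request.remote or "unknown"
        allowed, message = await limiter.check_rate_limit(client_ip, CommandCategory.WEB)
        if not allowed:
            return web.json_response({"error": message}, status=429)
        return await handler(request)
    return rate_limit_middleware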

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
megaproxy committed 2025-07-15 20:10:43 +00:00
commit 915aa00bea (parent f8ac661cd1)
28 changed files with 5730 additions and 57 deletions

src/backup_manager.py (new file, 458 lines)
@@ -0,0 +1,458 @@
import os
import shutil
import sqlite3
import gzip
import json
from datetime import datetime, timedelta
from typing import Dict, List, Optional
import asyncio
import aiosqlite
from pathlib import Path
import logging


class BackupManager:
    def __init__(self, db_path: str = "data/petbot.db", backup_dir: str = "backups"):
        self.db_path = db_path
        self.backup_dir = Path(backup_dir)
        self.backup_dir.mkdir(exist_ok=True)

        # Backup configuration
        self.max_daily_backups = 7     # Keep 7 daily backups
        self.max_weekly_backups = 4    # Keep 4 weekly backups
        self.max_monthly_backups = 12  # Keep 12 monthly backups

        # Setup logging
        self.logger = logging.getLogger(__name__)

    async def create_backup(self, backup_type: str = "manual", compress: bool = True) -> Dict:
        """Create a database backup with optional compression."""
        try:
            # Check if database exists
            if not os.path.exists(self.db_path):
                return {"success": False, "error": "Database file not found"}

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            backup_filename = f"petbot_backup_{backup_type}_{timestamp}.db"
            if compress:
                backup_filename += ".gz"
            backup_path = self.backup_dir / backup_filename

            # Create the backup
            if compress:
                await self._create_compressed_backup(backup_path)
            else:
                await self._create_regular_backup(backup_path)

            # Get backup info
            backup_info = await self._get_backup_info(backup_path)

            # Log the backup
            self.logger.info(f"Backup created: {backup_filename} ({backup_info['size_mb']:.1f}MB)")

            return {
                "success": True,
                "backup_path": str(backup_path),
                "backup_filename": backup_filename,
                "backup_type": backup_type,
                "timestamp": timestamp,
                "compressed": compress,
                "size_mb": backup_info["size_mb"]
            }
        except Exception as e:
            self.logger.error(f"Backup creation failed: {str(e)}")
            return {"success": False, "error": str(e)}

    async def _create_regular_backup(self, backup_path: Path):
        """Create a regular SQLite backup using the backup API."""
        def backup_db():
            # Use SQLite backup API for consistent backup
            source_conn = sqlite3.connect(self.db_path)
            backup_conn = sqlite3.connect(str(backup_path))

            # Perform the backup
            source_conn.backup(backup_conn)

            # Close connections
            source_conn.close()
            backup_conn.close()

        # Run the backup in a thread to avoid blocking
        await asyncio.get_event_loop().run_in_executor(None, backup_db)

    async def _create_compressed_backup(self, backup_path: Path):
        """Create a compressed backup."""
        def backup_and_compress():
            # First create temporary uncompressed backup
            temp_backup = backup_path.with_suffix('.tmp')

            # Use SQLite backup API
            source_conn = sqlite3.connect(self.db_path)
            backup_conn = sqlite3.connect(str(temp_backup))
            source_conn.backup(backup_conn)
            source_conn.close()
            backup_conn.close()

            # Compress the backup
            with open(temp_backup, 'rb') as f_in:
                with gzip.open(backup_path, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)

            # Remove temporary file
            temp_backup.unlink()

        await asyncio.get_event_loop().run_in_executor(None, backup_and_compress)

    async def _get_backup_info(self, backup_path: Path) -> Dict:
        """Get information about a backup file."""
        stat = backup_path.stat()
        return {
            "size_bytes": stat.st_size,
            "size_mb": stat.st_size / (1024 * 1024),
            "created_at": datetime.fromtimestamp(stat.st_mtime),
            "compressed": backup_path.suffix == '.gz'
        }

    async def list_backups(self) -> List[Dict]:
        """List all available backups with metadata."""
        backups = []

        for backup_file in self.backup_dir.glob("petbot_backup_*.db*"):
            try:
                info = await self._get_backup_info(backup_file)

                # Parse backup filename for metadata.
                # Format: petbot_backup_<type>_<YYYYMMDD>_<HHMMSS>[.db[.gz]]
                # Parse from the right: <type> may itself contain underscores
                # (e.g. "pre_restore") and the timestamp contains one.
                filename = backup_file.name
                parts = filename.replace('.gz', '').replace('.db', '').split('_')
                if len(parts) >= 5:
                    backup_type = '_'.join(parts[2:-2])
                    timestamp = '_'.join(parts[-2:])

                    backups.append({
                        "filename": filename,
                        "path": str(backup_file),
                        "type": backup_type,
                        "timestamp": timestamp,
                        "created_at": info["created_at"],
                        "size_mb": info["size_mb"],
                        "compressed": info["compressed"]
                    })
            except Exception as e:
                self.logger.warning(f"Error reading backup {backup_file}: {e}")
                continue

        # Sort by creation time (newest first)
        backups.sort(key=lambda x: x["created_at"], reverse=True)
        return backups

    async def restore_backup(self, backup_filename: str, target_path: Optional[str] = None) -> Dict:
        """Restore a database from backup."""
        try:
            backup_path = self.backup_dir / backup_filename
            if not backup_path.exists():
                return {"success": False, "error": "Backup file not found"}

            target_path = target_path or self.db_path

            # Create backup of current database before restore
            current_backup = await self.create_backup("pre_restore", compress=True)
            if not current_backup["success"]:
                return {"success": False, "error": "Failed to backup current database"}

            # Restore the backup
            if backup_path.suffix == '.gz':
                await self._restore_compressed_backup(backup_path, target_path)
            else:
                await self._restore_regular_backup(backup_path, target_path)

            # Verify the restored database
            verification = await self._verify_database(target_path)
            if not verification["success"]:
                return {"success": False, "error": f"Restored database verification failed: {verification['error']}"}

            self.logger.info(f"Database restored from backup: {backup_filename}")

            return {
                "success": True,
                "backup_filename": backup_filename,
                "target_path": target_path,
                "current_backup": current_backup["backup_filename"],
                "tables_verified": verification["table_count"]
            }
        except Exception as e:
            self.logger.error(f"Restore failed: {str(e)}")
            return {"success": False, "error": str(e)}

    async def _restore_regular_backup(self, backup_path: Path, target_path: str):
        """Restore from regular backup."""
        def restore():
            shutil.copy2(backup_path, target_path)

        await asyncio.get_event_loop().run_in_executor(None, restore)

    async def _restore_compressed_backup(self, backup_path: Path, target_path: str):
        """Restore from compressed backup."""
        def restore():
            with gzip.open(backup_path, 'rb') as f_in:
                with open(target_path, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)

        await asyncio.get_event_loop().run_in_executor(None, restore)

    async def _verify_database(self, db_path: str) -> Dict:
        """Verify database integrity and structure."""
        try:
            async with aiosqlite.connect(db_path) as db:
                # Check database integrity
                cursor = await db.execute("PRAGMA integrity_check")
                integrity_result = await cursor.fetchone()

                if integrity_result[0] != "ok":
                    return {"success": False, "error": f"Database integrity check failed: {integrity_result[0]}"}

                # Count tables
                cursor = await db.execute("SELECT COUNT(*) FROM sqlite_master WHERE type='table'")
                table_count = (await cursor.fetchone())[0]

                # Basic table existence check
                required_tables = ["players", "pets", "pet_species", "moves", "items"]
                for table in required_tables:
                    cursor = await db.execute("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", (table,))
                    exists = (await cursor.fetchone())[0]
                    if not exists:
                        return {"success": False, "error": f"Required table '{table}' not found"}

                return {"success": True, "table_count": table_count}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def cleanup_old_backups(self) -> Dict:
        """Remove old backups based on retention policy."""
        try:
            backups = await self.list_backups()

            # Group backups by type
            daily_backups = [b for b in backups if b["type"] in ["daily", "manual"]]
            weekly_backups = [b for b in backups if b["type"] == "weekly"]
            monthly_backups = [b for b in backups if b["type"] == "monthly"]

            cleaned_count = 0

            # Clean daily backups (keep most recent)
            if len(daily_backups) > self.max_daily_backups:
                old_daily = daily_backups[self.max_daily_backups:]
                for backup in old_daily:
                    await self._remove_backup(backup["path"])
                    cleaned_count += 1

            # Clean weekly backups
            if len(weekly_backups) > self.max_weekly_backups:
                old_weekly = weekly_backups[self.max_weekly_backups:]
                for backup in old_weekly:
                    await self._remove_backup(backup["path"])
                    cleaned_count += 1

            # Clean monthly backups
            if len(monthly_backups) > self.max_monthly_backups:
                old_monthly = monthly_backups[self.max_monthly_backups:]
                for backup in old_monthly:
                    await self._remove_backup(backup["path"])
                    cleaned_count += 1

            self.logger.info(f"Cleaned up {cleaned_count} old backups")

            return {
                "success": True,
                "cleaned_count": cleaned_count,
                "remaining_backups": len(backups) - cleaned_count
            }
        except Exception as e:
            self.logger.error(f"Cleanup failed: {str(e)}")
            return {"success": False, "error": str(e)}

    async def _remove_backup(self, backup_path: str):
        """Remove a backup file."""
        try:
            Path(backup_path).unlink()
            self.logger.debug(f"Removed backup: {backup_path}")
        except Exception as e:
            self.logger.warning(f"Failed to remove backup {backup_path}: {e}")

    async def get_backup_stats(self) -> Dict:
        """Get statistics about backups."""
        try:
            backups = await self.list_backups()

            if not backups:
                return {
                    "success": True,
                    "total_backups": 0,
                    "total_size_mb": 0,
                    "oldest_backup": None,
                    "newest_backup": None,
                    "by_type": {}
                }

            total_size = sum(b["size_mb"] for b in backups)
            oldest = min(backups, key=lambda x: x["created_at"])
            newest = max(backups, key=lambda x: x["created_at"])

            # Group by type
            by_type = {}
            for backup in backups:
                backup_type = backup["type"]
                if backup_type not in by_type:
                    by_type[backup_type] = {"count": 0, "size_mb": 0}
                by_type[backup_type]["count"] += 1
                by_type[backup_type]["size_mb"] += backup["size_mb"]

            return {
                "success": True,
                "total_backups": len(backups),
                "total_size_mb": round(total_size, 1),
                "oldest_backup": oldest["created_at"],
                "newest_backup": newest["created_at"],
                "by_type": by_type
            }
        except Exception as e:
            self.logger.error(f"Failed to get backup stats: {str(e)}")
            return {"success": False, "error": str(e)}

    async def export_database_structure(self) -> Dict:
        """Export database schema for documentation/analysis."""
        try:
            structure = {
                "export_time": datetime.now().isoformat(),
                "database_path": self.db_path,
                "tables": {}
            }

            async with aiosqlite.connect(self.db_path) as db:
                # Get all tables
                cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table'")
                tables = await cursor.fetchall()

                for table_name in [t[0] for t in tables]:
                    # Get table info
                    cursor = await db.execute(f"PRAGMA table_info({table_name})")
                    columns = await cursor.fetchall()

                    # Get row count
                    cursor = await db.execute(f"SELECT COUNT(*) FROM {table_name}")
                    row_count = (await cursor.fetchone())[0]

                    structure["tables"][table_name] = {
                        "columns": [
                            {
                                "name": col[1],
                                "type": col[2],
                                "not_null": bool(col[3]),
                                "default": col[4],
                                "primary_key": bool(col[5])
                            }
                            for col in columns
                        ],
                        "row_count": row_count
                    }

            # Save structure to file
            structure_path = self.backup_dir / f"database_structure_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            with open(structure_path, 'w') as f:
                json.dump(structure, f, indent=2, default=str)

            return {
                "success": True,
                "structure_path": str(structure_path),
                "table_count": len(structure["tables"]),
                "total_rows": sum(table["row_count"] for table in structure["tables"].values())
            }
        except Exception as e:
            self.logger.error(f"Failed to export database structure: {str(e)}")
            return {"success": False, "error": str(e)}


# Scheduler for automated backups
class BackupScheduler:
    def __init__(self, backup_manager: BackupManager):
        self.backup_manager = backup_manager
        self.logger = logging.getLogger(__name__)
        self.running = False

    async def start_scheduler(self):
        """Start the backup scheduler."""
        self.running = True
        self.logger.info("Backup scheduler started")

        while self.running:
            try:
                await self._check_and_create_backups()
                await asyncio.sleep(3600)  # Check every hour
            except Exception as e:
                self.logger.error(f"Scheduler error: {str(e)}")
                await asyncio.sleep(3600)  # Wait before retrying

    def stop_scheduler(self):
        """Stop the backup scheduler."""
        self.running = False
        self.logger.info("Backup scheduler stopped")

    async def _check_and_create_backups(self):
        """Check if backups are needed and create them."""
        # Check daily backup (every 24 hours)
        if await self._should_create_backup("daily", hours=24):
            result = await self.backup_manager.create_backup("daily", compress=True)
            if result["success"]:
                self.logger.info(f"Daily backup created: {result['backup_filename']}")

        # Check weekly backup (every 7 days)
        if await self._should_create_backup("weekly", days=7):
            result = await self.backup_manager.create_backup("weekly", compress=True)
            if result["success"]:
                self.logger.info(f"Weekly backup created: {result['backup_filename']}")

        # Check monthly backup (every 30 days)
        if await self._should_create_backup("monthly", days=30):
            result = await self.backup_manager.create_backup("monthly", compress=True)
            if result["success"]:
                self.logger.info(f"Monthly backup created: {result['backup_filename']}")

        # Cleanup old backups
        await self.backup_manager.cleanup_old_backups()

    async def _should_create_backup(self, backup_type: str, hours: int = 0, days: int = 0) -> bool:
        """Check if a backup of the specified type should be created."""
        try:
            backups = await self.backup_manager.list_backups()

            # Find most recent backup of this type
            type_backups = [b for b in backups if b["type"] == backup_type]
            if not type_backups:
                return True  # No backups of this type exist

            most_recent = max(type_backups, key=lambda x: x["created_at"])
            time_since = datetime.now() - most_recent["created_at"]
            required_delta = timedelta(hours=hours, days=days)

            return time_since >= required_delta
        except Exception as e:
            self.logger.error(f"Error checking backup schedule: {str(e)}")
            return False
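
A minimal usage sketch for the classes above (the wiring is assumed; the bot's actual startup code is elsewhere in this commit):

import asyncio

from src.backup_manager import BackupManager, BackupScheduler  # module path assumed importable

async def main():
    manager = BackupManager(db_path="data/petbot.db", backup_dir="backups")

    result = await manager.create_backup("manual", compress=True)
    if result["success"]:
        print(f"Created {result['backup_filename']} ({result['size_mb']:.1f} MB)")

    # In the bot, the scheduler would run alongside the main loop:
    scheduler = BackupScheduler(manager)
    task = asyncio.create_task(scheduler.start_scheduler())
    # ... run the bot; on shutdown:
    scheduler.stop_scheduler()
    task.cancel()

asyncio.run(main())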


@@ -365,6 +365,9 @@ class GameEngine:
         except FileNotFoundError:
             return None
 
+        # Get global spawn multiplier from config
+        global_multiplier = items_data.get("_config", {}).get("global_spawn_multiplier", 1.0)
+
         # Get all possible items for this location
         available_items = []
         location_name = location["name"].lower().replace(" ", "_")
@@ -375,22 +378,25 @@ class GameEngine:
             if "locations" in item:
                 item_locations = item["locations"]
                 if "all" in item_locations or location_name in item_locations:
-                    available_items.append(item)
+                    # Apply global multiplier to spawn rate
+                    item_copy = item.copy()
+                    item_copy["effective_spawn_rate"] = item.get("spawn_rate", 0.1) * global_multiplier
+                    available_items.append(item_copy)
 
         if not available_items:
             return None
 
-        # Calculate total spawn rates for this location
-        total_rate = sum(item.get("spawn_rate", 0.1) for item in available_items)
+        # Calculate total spawn rates for this location (using effective rates)
+        total_rate = sum(item.get("effective_spawn_rate", 0.1) for item in available_items)
 
         # 30% base chance of finding an item
         if random.random() > 0.3:
             return None
 
-        # Choose item based on spawn rates
+        # Choose item based on effective spawn rates (with global multiplier applied)
         chosen_item = random.choices(
             available_items,
-            weights=[item.get("spawn_rate", 0.1) for item in available_items]
+            weights=[item.get("effective_spawn_rate", 0.1) for item in available_items]
         )[0]
 
         # Add item to player's inventory

(new file, 395 lines)

@@ -0,0 +1,395 @@
import asyncio
import socket
import time
import logging
import random
from enum import Enum
from typing import Optional, Callable, Dict, Any
from datetime import datetime, timedelta


class ConnectionState(Enum):
    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    AUTHENTICATED = "authenticated"
    JOINED = "joined"
    RECONNECTING = "reconnecting"
    FAILED = "failed"


class IRCConnectionManager:
    """
    Robust IRC connection manager with automatic reconnection,
    health monitoring, and exponential backoff.
    """

    def __init__(self, config: Dict[str, Any], bot_instance=None):
        self.config = config
        self.bot = bot_instance
        self.socket = None
        self.state = ConnectionState.DISCONNECTED
        self.running = False

        # Connection monitoring
        self.last_ping_time = 0
        self.last_pong_time = 0
        self.ping_interval = 60  # Send PING every 60 seconds
        self.ping_timeout = 120  # Expect PONG within 2 minutes

        # Reconnection settings
        self.reconnect_attempts = 0
        self.max_reconnect_attempts = 50
        self.base_reconnect_delay = 1   # Start with 1 second
        self.max_reconnect_delay = 300  # Cap at 5 minutes
        self.reconnect_jitter = 0.1     # 10% jitter

        # Connection tracking
        self.connection_start_time = None
        self.last_successful_connection = None
        self.total_reconnections = 0
        self.connection_failures = 0

        # Event callbacks
        self.on_connect_callback = None
        self.on_disconnect_callback = None
        self.on_message_callback = None
        self.on_connection_lost_callback = None

        # Health monitoring
        self.health_check_interval = 30  # Check health every 30 seconds
        self.health_check_task = None
        self.message_count = 0
        self.last_message_time = 0

        # Setup logging
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

        # Create console handler if none exists
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def set_callbacks(self, on_connect=None, on_disconnect=None, on_message=None, on_connection_lost=None):
        """Set callback functions for connection events."""
        self.on_connect_callback = on_connect
        self.on_disconnect_callback = on_disconnect
        self.on_message_callback = on_message
        self.on_connection_lost_callback = on_connection_lost

    async def start(self):
        """Start the connection manager."""
        if self.running:
            self.logger.warning("Connection manager is already running")
            return

        self.running = True
        self.logger.info("Starting IRC connection manager")

        # Start health monitoring
        self.health_check_task = asyncio.create_task(self._health_monitor())

        # Start connection loop
        await self._connection_loop()

    async def stop(self):
        """Stop the connection manager and close connections."""
        self.running = False
        self.logger.info("Stopping IRC connection manager")

        # Cancel health monitoring
        if self.health_check_task:
            self.health_check_task.cancel()
            try:
                await self.health_check_task
            except asyncio.CancelledError:
                pass

        # Close socket
        await self._disconnect()

    async def _connection_loop(self):
        """Main connection loop with automatic reconnection."""
        while self.running:
            try:
                if self.state == ConnectionState.DISCONNECTED:
                    await self._connect()

                if self.state in [ConnectionState.CONNECTED, ConnectionState.AUTHENTICATED, ConnectionState.JOINED]:
                    await self._handle_messages()

                await asyncio.sleep(0.1)
            except Exception as e:
                self.logger.error(f"Error in connection loop: {e}")
                await self._handle_connection_error(e)

    async def _connect(self):
        """Connect to IRC server with retry logic."""
        if self.reconnect_attempts >= self.max_reconnect_attempts:
            self.logger.error(f"Maximum reconnection attempts ({self.max_reconnect_attempts}) reached")
            self.state = ConnectionState.FAILED
            return

        self.state = ConnectionState.CONNECTING
        self.connection_start_time = datetime.now()

        try:
            # Calculate reconnection delay with exponential backoff
            if self.reconnect_attempts > 0:
                delay = min(
                    self.base_reconnect_delay * (2 ** self.reconnect_attempts),
                    self.max_reconnect_delay
                )
                # Add jitter to prevent thundering herd
                jitter = delay * self.reconnect_jitter * random.random()
                delay += jitter

                self.logger.info(f"Reconnection attempt {self.reconnect_attempts + 1}/{self.max_reconnect_attempts} after {delay:.1f}s delay")
                await asyncio.sleep(delay)

            # Test basic connectivity first
            await self._test_connectivity()

            # Create socket connection
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket.settimeout(10)  # 10 second timeout for connection

            # Connect to server
            await asyncio.get_event_loop().run_in_executor(
                None, self.socket.connect, (self.config["server"], self.config["port"])
            )

            self.socket.settimeout(1)  # Shorter timeout for message handling
            self.state = ConnectionState.CONNECTED

            # Send IRC handshake
            await self._send_handshake()

            # Reset reconnection counter on successful connection
            self.reconnect_attempts = 0
            self.last_successful_connection = datetime.now()
            self.total_reconnections += 1

            self.logger.info(f"Successfully connected to {self.config['server']}:{self.config['port']}")
        except Exception as e:
            self.logger.error(f"Connection failed: {e}")
            self.connection_failures += 1
            self.reconnect_attempts += 1
            await self._disconnect()

            if self.on_connection_lost_callback:
                await self.on_connection_lost_callback(e)

    async def _test_connectivity(self):
        """Test basic network connectivity to IRC server."""
        try:
            test_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            test_sock.settimeout(5)
            await asyncio.get_event_loop().run_in_executor(
                None, test_sock.connect, (self.config["server"], self.config["port"])
            )
            test_sock.close()
        except Exception as e:
            raise ConnectionError(f"Network connectivity test failed: {e}")

    async def _send_handshake(self):
        """Send IRC handshake messages."""
        nickname = self.config["nickname"]
        await self._send_raw(f"NICK {nickname}")
        await self._send_raw(f"USER {nickname} 0 * :{nickname}")

        # Initialize ping tracking
        self.last_ping_time = time.time()
        self.last_pong_time = time.time()

    async def _handle_messages(self):
        """Handle incoming IRC messages."""
        try:
            data = await asyncio.get_event_loop().run_in_executor(
                None, self.socket.recv, 4096
            )

            if not data:
                raise ConnectionError("Connection closed by server")

            # Update message tracking
            self.message_count += 1
            self.last_message_time = time.time()

            # Decode and process messages
            lines = data.decode('utf-8', errors='ignore').strip().split('\n')
            for line in lines:
                if line.strip():
                    await self._process_line(line.strip())
        except socket.timeout:
            # Timeout is expected, continue
            pass
        except Exception as e:
            raise ConnectionError(f"Message handling error: {e}")

    async def _process_line(self, line):
        """Process a single IRC line."""
        # Handle PING/PONG
        if line.startswith("PING"):
            pong_response = line.replace("PING", "PONG")
            await self._send_raw(pong_response)
            return

        if line.startswith("PONG"):
            self.last_pong_time = time.time()
            return

        # Handle connection completion
        if "376" in line or "422" in line:  # End of MOTD
            if self.state == ConnectionState.CONNECTED:
                self.state = ConnectionState.AUTHENTICATED
                await self._send_raw(f"JOIN {self.config['channel']}")

        # Handle successful channel join
        if " JOIN " in line and self.config["channel"] in line:
            if self.state == ConnectionState.AUTHENTICATED:
                self.state = ConnectionState.JOINED
                self.logger.info(f"Successfully joined {self.config['channel']}")

                if self.on_connect_callback:
                    await self.on_connect_callback()

        # Handle nickname conflicts
        if "433" in line:  # Nickname in use
            new_nickname = f"{self.config['nickname']}_"
            self.logger.warning(f"Nickname conflict, trying {new_nickname}")
            await self._send_raw(f"NICK {new_nickname}")

        # Handle disconnection
        if "ERROR :Closing Link" in line:
            raise ConnectionError("Server closed connection")

        # Forward message to callback
        if self.on_message_callback:
            await self.on_message_callback(line)

    async def _send_raw(self, message):
        """Send raw IRC message."""
        if not self.socket:
            raise ConnectionError("Not connected to IRC server")

        try:
            full_message = f"{message}\r\n"
            # sendall avoids the partial writes that socket.send allows
            await asyncio.get_event_loop().run_in_executor(
                None, self.socket.sendall, full_message.encode('utf-8')
            )
        except Exception as e:
            raise ConnectionError(f"Failed to send message: {e}")

    async def send_message(self, target, message):
        """Send a message to a channel or user."""
        if self.state != ConnectionState.JOINED:
            self.logger.warning(f"Cannot send message, not joined to channel (state: {self.state})")
            return False

        try:
            await self._send_raw(f"PRIVMSG {target} :{message}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to send message to {target}: {e}")
            return False

    async def _health_monitor(self):
        """Monitor connection health and send periodic pings."""
        while self.running:
            try:
                await asyncio.sleep(self.health_check_interval)

                if self.state == ConnectionState.JOINED:
                    await self._check_connection_health()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Health monitor error: {e}")

    async def _check_connection_health(self):
        """Check if connection is healthy and send pings as needed."""
        current_time = time.time()

        # Send ping if interval has passed
        if current_time - self.last_ping_time > self.ping_interval:
            try:
                await self._send_raw(f"PING :health_check_{int(current_time)}")
                self.last_ping_time = current_time
            except Exception as e:
                self.logger.error(f"Failed to send ping: {e}")
                raise ConnectionError("Health check ping failed")

        # Check if we've received a pong recently
        if current_time - self.last_pong_time > self.ping_timeout:
            self.logger.warning("No PONG received within timeout period")
            raise ConnectionError("Ping timeout - connection appears dead")

    async def _handle_connection_error(self, error):
        """Handle connection errors and initiate reconnection."""
        self.logger.error(f"Connection error: {error}")

        # Notify callback
        if self.on_disconnect_callback:
            await self.on_disconnect_callback(error)

        # Disconnect and prepare for reconnection
        await self._disconnect()

        # Set state for reconnection
        if self.running:
            self.state = ConnectionState.DISCONNECTED
            self.reconnect_attempts += 1

    async def _disconnect(self):
        """Disconnect from IRC server."""
        if self.socket:
            try:
                self.socket.close()
            except Exception:
                pass
            self.socket = None

        old_state = self.state
        self.state = ConnectionState.DISCONNECTED

        if old_state != ConnectionState.DISCONNECTED:
            self.logger.info("Disconnected from IRC server")

    def get_connection_stats(self) -> Dict[str, Any]:
        """Get connection statistics."""
        uptime = None
        if self.connection_start_time:
            uptime = datetime.now() - self.connection_start_time

        return {
            "state": self.state.value,
            "connected": self.state == ConnectionState.JOINED,
            "uptime": str(uptime) if uptime else None,
            "reconnect_attempts": self.reconnect_attempts,
            "total_reconnections": self.total_reconnections,
            "connection_failures": self.connection_failures,
            "last_successful_connection": self.last_successful_connection,
            "message_count": self.message_count,
            "last_message_time": datetime.fromtimestamp(self.last_message_time) if self.last_message_time else None,
            "last_ping_time": datetime.fromtimestamp(self.last_ping_time) if self.last_ping_time else None,
            "last_pong_time": datetime.fromtimestamp(self.last_pong_time) if self.last_pong_time else None
        }

    def is_connected(self) -> bool:
        """Check if bot is connected and ready."""
        return self.state == ConnectionState.JOINED

    def get_state(self) -> ConnectionState:
        """Get current connection state."""
        return self.state
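
A minimal usage sketch (the module name and config values are assumptions; the file path for this listing is not shown in the commit view):

import asyncio

from connection_manager import IRCConnectionManager  # module name assumed

async def main():
    config = {"server": "irc.example.net", "port": 6667,
              "nickname": "petbot", "channel": "#petbot"}  # illustrative values

    manager = IRCConnectionManager(config)

    async def on_message(line: str):
        print("IRC:", line)

    manager.set_callbacks(on_message=on_message)
    try:
        await manager.start()  # blocks in the connection loop until stop()
    finally:
        await manager.stop()

asyncio.run(main())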

src/rate_limiter.py (new file, 426 lines)
@@ -0,0 +1,426 @@
import time
import asyncio
from typing import Dict, Optional, Tuple
from datetime import datetime, timedelta
from enum import Enum
import logging


class CommandCategory(Enum):
    """Categories of commands with different rate limits."""
    BASIC = "basic"            # !help, !ping, !status
    GAMEPLAY = "gameplay"      # !explore, !catch, !battle
    MANAGEMENT = "management"  # !pets, !activate, !deactivate
    ADMIN = "admin"            # !backup, !reload, !reconnect
    WEB = "web"                # Web interface requests


class RateLimiter:
    """
    Token bucket rate limiter with per-user tracking and command categories.

    Features:
    - Token bucket algorithm for smooth rate limiting
    - Per-user rate tracking
    - Different limits for different command categories
    - Burst capacity handling
    - Admin exemption
    - Detailed logging and monitoring
    """

    def __init__(self, config: Optional[Dict] = None):
        self.logger = logging.getLogger(__name__)

        # Default rate limit configuration
        self.config = {
            "enabled": True,
            "categories": {
                CommandCategory.BASIC: {
                    "requests_per_minute": 20,
                    "burst_capacity": 5,
                    "cooldown_seconds": 1
                },
                CommandCategory.GAMEPLAY: {
                    "requests_per_minute": 10,
                    "burst_capacity": 3,
                    "cooldown_seconds": 3
                },
                CommandCategory.MANAGEMENT: {
                    "requests_per_minute": 5,
                    "burst_capacity": 2,
                    "cooldown_seconds": 5
                },
                CommandCategory.ADMIN: {
                    "requests_per_minute": 100,
                    "burst_capacity": 10,
                    "cooldown_seconds": 0
                },
                CommandCategory.WEB: {
                    "requests_per_minute": 60,
                    "burst_capacity": 10,
                    "cooldown_seconds": 1
                }
            },
            "admin_users": ["megasconed"],  # This will be overridden by bot initialization
            "global_limits": {
                "max_requests_per_minute": 200,
                "max_concurrent_users": 100
            },
            "violation_penalties": {
                "warning_threshold": 3,
                "temporary_ban_threshold": 10,
                "temporary_ban_duration": 300  # 5 minutes
            }
        }

        # Override with provided config
        if config:
            self._update_config(config)

        # Rate limiting state (note: buckets are keyed per user and shared
        # across categories)
        self.user_buckets: Dict[str, Dict] = {}
        self.global_stats = {
            "requests_this_minute": 0,
            "minute_start": time.time(),
            "active_users": set(),
            "total_requests": 0,
            "blocked_requests": 0
        }

        # Violation tracking
        self.violations: Dict[str, Dict] = {}
        self.banned_users: Dict[str, float] = {}  # user -> ban_end_time

        # Background cleanup task
        self.cleanup_task = None
        self.start_cleanup_task()

    def _update_config(self, config: Dict):
        """Update configuration with provided values."""
        def deep_update(base_dict, update_dict):
            for key, value in update_dict.items():
                if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
                    deep_update(base_dict[key], value)
                else:
                    base_dict[key] = value

        deep_update(self.config, config)

    def start_cleanup_task(self):
        """Start background cleanup task."""
        if self.cleanup_task is None or self.cleanup_task.done():
            try:
                self.cleanup_task = asyncio.create_task(self._cleanup_loop())
            except RuntimeError:
                # No running event loop yet (e.g. constructed outside async
                # code); the caller can invoke start_cleanup_task() later.
                self.cleanup_task = None

    async def _cleanup_loop(self):
        """Background task to clean up old data."""
        while True:
            try:
                await asyncio.sleep(60)  # Cleanup every minute
                await self._cleanup_old_data()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Error in rate limiter cleanup: {e}")

    async def _cleanup_old_data(self):
        """Clean up old rate limiting data."""
        current_time = time.time()

        # Clean up old user buckets (inactive for 10 minutes)
        inactive_threshold = current_time - 600
        inactive_users = [
            user for user, data in self.user_buckets.items()
            if data.get("last_request", 0) < inactive_threshold
        ]
        for user in inactive_users:
            del self.user_buckets[user]

        # Clean up old violations (older than 1 hour)
        violation_threshold = current_time - 3600
        old_violations = [
            user for user, data in self.violations.items()
            if data.get("last_violation", 0) < violation_threshold
        ]
        for user in old_violations:
            del self.violations[user]

        # Clean up expired bans
        expired_bans = [
            user for user, ban_end in self.banned_users.items()
            if current_time > ban_end
        ]
        for user in expired_bans:
            del self.banned_users[user]
            self.logger.info(f"Temporary ban expired for user: {user}")

        # Reset global stats every minute
        if current_time - self.global_stats["minute_start"] >= 60:
            self.global_stats["requests_this_minute"] = 0
            self.global_stats["minute_start"] = current_time
            self.global_stats["active_users"].clear()

    async def check_rate_limit(self, user: str, category: CommandCategory,
                               command: Optional[str] = None) -> Tuple[bool, Optional[str]]:
        """
        Check if a request is allowed under rate limiting.

        Returns:
            (allowed: bool, message: Optional[str])
        """
        if not self.config["enabled"]:
            return True, None

        current_time = time.time()

        # Check if user is temporarily banned
        if user in self.banned_users:
            if current_time < self.banned_users[user]:
                remaining = int(self.banned_users[user] - current_time)
                return False, f"⛔ You are temporarily banned for {remaining} seconds due to rate limit violations."
            else:
                del self.banned_users[user]

        # Admin exemption
        if user.lower() in [admin.lower() for admin in self.config["admin_users"]]:
            return True, None

        # Check global limits
        if not await self._check_global_limits():
            return False, "🚫 Server is currently overloaded. Please try again later."

        # Check per-user rate limit
        allowed, message = await self._check_user_rate_limit(user, category, current_time)

        if allowed:
            # Update global stats
            self.global_stats["requests_this_minute"] += 1
            self.global_stats["total_requests"] += 1
            self.global_stats["active_users"].add(user)
        else:
            # Track violation
            await self._track_violation(user, category, command)
            self.global_stats["blocked_requests"] += 1

        return allowed, message

    async def _check_global_limits(self) -> bool:
        """Check global rate limits."""
        # Check requests per minute
        if (self.global_stats["requests_this_minute"] >=
                self.config["global_limits"]["max_requests_per_minute"]):
            return False

        # Check concurrent users
        if (len(self.global_stats["active_users"]) >=
                self.config["global_limits"]["max_concurrent_users"]):
            return False

        return True

    async def _check_user_rate_limit(self, user: str, category: CommandCategory,
                                     current_time: float) -> Tuple[bool, Optional[str]]:
        """Check per-user rate limit using token bucket algorithm."""
        category_config = self.config["categories"][category]

        # Get or create user bucket
        if user not in self.user_buckets:
            self.user_buckets[user] = {
                "tokens": category_config["burst_capacity"],
                "last_refill": current_time,
                "last_request": current_time
            }

        bucket = self.user_buckets[user]

        # Calculate tokens to add (refill rate)
        time_passed = current_time - bucket["last_refill"]
        tokens_per_second = category_config["requests_per_minute"] / 60.0
        tokens_to_add = time_passed * tokens_per_second

        # Refill bucket (up to burst capacity)
        bucket["tokens"] = min(
            category_config["burst_capacity"],
            bucket["tokens"] + tokens_to_add
        )
        bucket["last_refill"] = current_time

        # Check if request is allowed
        if bucket["tokens"] >= 1:
            bucket["tokens"] -= 1
            bucket["last_request"] = current_time
            return True, None
        else:
            # Calculate cooldown time
            time_since_last = current_time - bucket["last_request"]
            cooldown = category_config["cooldown_seconds"]

            if time_since_last < cooldown:
                remaining = int(cooldown - time_since_last)
                return False, f"⏱️ Rate limit exceeded. Please wait {remaining} seconds before using {category.value} commands."
            else:
                # Allow if cooldown has passed
                bucket["tokens"] = category_config["burst_capacity"] - 1
                bucket["last_request"] = current_time
                return True, None

    async def _track_violation(self, user: str, category: CommandCategory, command: Optional[str]):
        """Track rate limit violations for potential penalties."""
        current_time = time.time()

        if user not in self.violations:
            self.violations[user] = {
                "count": 0,
                "last_violation": current_time,
                "categories": {}
            }

        violation = self.violations[user]
        violation["count"] += 1
        violation["last_violation"] = current_time

        # Track by category
        if category.value not in violation["categories"]:
            violation["categories"][category.value] = 0
        violation["categories"][category.value] += 1

        # Check for penalties
        penalty_config = self.config["violation_penalties"]

        if violation["count"] >= penalty_config["temporary_ban_threshold"]:
            # Temporary ban
            ban_duration = penalty_config["temporary_ban_duration"]
            self.banned_users[user] = current_time + ban_duration
            self.logger.warning(f"User {user} temporarily banned for {ban_duration}s due to rate limit violations")
        elif violation["count"] >= penalty_config["warning_threshold"]:
            # Warning threshold reached
            self.logger.warning(f"User {user} reached rate limit warning threshold ({violation['count']} violations)")

    def get_user_stats(self, user: str) -> Dict:
        """Get rate limiting stats for a specific user."""
        stats = {
            "user": user,
            "is_banned": user in self.banned_users,
            "ban_expires": None,
            "violations": 0,
            "buckets": {},
            "admin_exemption": user.lower() in [admin.lower() for admin in self.config["admin_users"]]
        }

        # Ban info
        if stats["is_banned"]:
            stats["ban_expires"] = datetime.fromtimestamp(self.banned_users[user])

        # Violation info
        if user in self.violations:
            stats["violations"] = self.violations[user]["count"]

        # Bucket info
        if user in self.user_buckets:
            bucket = self.user_buckets[user]
            stats["buckets"] = {
                "tokens": round(bucket["tokens"], 2),
                "last_request": datetime.fromtimestamp(bucket["last_request"])
            }

        return stats

    def get_global_stats(self) -> Dict:
        """Get global rate limiting statistics."""
        return {
            "enabled": self.config["enabled"],
            "requests_this_minute": self.global_stats["requests_this_minute"],
            "active_users": len(self.global_stats["active_users"]),
            "total_requests": self.global_stats["total_requests"],
            "blocked_requests": self.global_stats["blocked_requests"],
            "banned_users": len(self.banned_users),
            "tracked_users": len(self.user_buckets),
            "total_violations": sum(v["count"] for v in self.violations.values()),
            "config": self.config
        }

    def is_user_banned(self, user: str) -> bool:
        """Check if a user is currently banned."""
        if user not in self.banned_users:
            return False

        if time.time() > self.banned_users[user]:
            del self.banned_users[user]
            return False

        return True

    def unban_user(self, user: str) -> bool:
        """Manually unban a user (admin function)."""
        if user in self.banned_users:
            del self.banned_users[user]
            self.logger.info(f"User {user} manually unbanned")
            return True
        return False

    def reset_user_violations(self, user: str) -> bool:
        """Reset violations for a user (admin function)."""
        if user in self.violations:
            del self.violations[user]
            self.logger.info(f"Violations reset for user {user}")
            return True
        return False

    async def shutdown(self):
        """Shutdown the rate limiter and cleanup tasks."""
        if self.cleanup_task:
            self.cleanup_task.cancel()
            try:
                await self.cleanup_task
            except asyncio.CancelledError:
                pass
        self.logger.info("Rate limiter shutdown complete")


# Command category mapping
COMMAND_CATEGORIES = {
    # Basic commands
    "help": CommandCategory.BASIC,
    "ping": CommandCategory.BASIC,
    "status": CommandCategory.BASIC,
    "uptime": CommandCategory.BASIC,
    "connection_stats": CommandCategory.BASIC,

    # Gameplay commands
    "start": CommandCategory.GAMEPLAY,
    "explore": CommandCategory.GAMEPLAY,
    "catch": CommandCategory.GAMEPLAY,
    "battle": CommandCategory.GAMEPLAY,
    "attack": CommandCategory.GAMEPLAY,
    "moves": CommandCategory.GAMEPLAY,
    "flee": CommandCategory.GAMEPLAY,
    "travel": CommandCategory.GAMEPLAY,
    "weather": CommandCategory.GAMEPLAY,
    "gym": CommandCategory.GAMEPLAY,

    # Management commands
    "pets": CommandCategory.MANAGEMENT,
    "activate": CommandCategory.MANAGEMENT,
    "deactivate": CommandCategory.MANAGEMENT,
    "stats": CommandCategory.MANAGEMENT,
    "inventory": CommandCategory.MANAGEMENT,
    "use": CommandCategory.MANAGEMENT,
    "nickname": CommandCategory.MANAGEMENT,

    # Admin commands
    "backup": CommandCategory.ADMIN,
    "restore": CommandCategory.ADMIN,
    "backups": CommandCategory.ADMIN,
    "backup_stats": CommandCategory.ADMIN,
    "backup_cleanup": CommandCategory.ADMIN,
    "reload": CommandCategory.ADMIN,
    "reconnect": CommandCategory.ADMIN,
}


def get_command_category(command: str) -> CommandCategory:
    """Get the rate limiting category for a command."""
    return COMMAND_CATEGORIES.get(command.lower(), CommandCategory.GAMEPLAY)
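
And a minimal usage sketch for the limiter (the admin override value is illustrative):

import asyncio

from src.rate_limiter import RateLimiter, get_command_category  # module path assumed importable

async def main():
    limiter = RateLimiter({"admin_users": ["megaproxy"]})  # illustrative override

    category = get_command_category("explore")  # -> CommandCategory.GAMEPLAY
    allowed, message = await limiter.check_rate_limit("some_user", category, "explore")
    if not allowed:
        print(message)  # user-facing notice, e.g. the cooldown warning

    await limiter.shutdown()

asyncio.run(main())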