subreddit-announcer/bot.py

475 lines
18 KiB
Python

#!/usr/bin/env python3
"""subreddit-announcer — a Discord bot that announces new Reddit posts.

A long-running discord.py bot. It polls the `/new` listing of each watched
subreddit on a fixed interval and posts an embed to the mapped Discord channel
for every submission it hasn't seen before.

State lives in a local SQLite DB:

- `watches`    : (subreddit, guild) -> channel mapping + bootstrap flag
- `seen_posts` : every submission id we've already handled, for dedup

When a subreddit is first watched it is *bootstrapped silently* — the posts
currently on /new are marked seen without announcing, so adding a busy
subreddit doesn't dump its backlog into the channel. Only posts created after
that point are announced.

Subreddits are managed at runtime with slash commands:

    /watch <subreddit> [channel]   start announcing a subreddit's new posts
    /unwatch <subreddit>           stop announcing it
    /watching                      list active watches in this server
    /setinterval <seconds>         change the global poll interval

Configuration comes from environment variables (see config.example.env).
Reddit access is read-only via a registered app's client id/secret — no
Reddit account login is required.
"""
from __future__ import annotations
import datetime
import logging
import os
import sqlite3
import time
import asyncpraw
import discord
from discord import app_commands
from discord.ext import tasks
# --------------------------------------------------------------------------
# Configuration (env-overridable; secrets stay in config.env, never in git)
# --------------------------------------------------------------------------
# Credentials; each falls back to "" when unset so main() can report what is
# missing instead of failing at import time.
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN", "")
REDDIT_CLIENT_ID = os.getenv("REDDIT_CLIENT_ID", "")
REDDIT_CLIENT_SECRET = os.getenv("REDDIT_CLIENT_SECRET", "")

# User-Agent string handed to the Reddit client.
REDDIT_USER_AGENT = os.getenv(
    "REDDIT_USER_AGENT",
    "discord:subreddit-announcer:1.0 (by /u/megaproxy)",
)

# SQLite state file; by default it sits next to this module.
DB_PATH = os.getenv(
    "ANNOUNCER_DB",
    os.path.join(os.path.dirname(__file__), "announcer.db"),
)

# Seconds between poll sweeps across all watched subreddits.
POLL_INTERVAL = float(os.getenv("ANNOUNCER_POLL_INTERVAL", "60"))

# Number of posts pulled from each subreddit's /new per sweep. 25 is enough
# unless a subreddit receives more than 25 posts within one POLL_INTERVAL.
FETCH_LIMIT = int(os.getenv("ANNOUNCER_FETCH_LIMIT", "25"))

# Root logging setup; the level is overridable through ANNOUNCER_LOGLEVEL.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=os.getenv("ANNOUNCER_LOGLEVEL", "INFO"),
)
log = logging.getLogger("announcer")
# --------------------------------------------------------------------------
# State storage
# --------------------------------------------------------------------------
class Store:
    """SQLite-backed state: which subreddits go to which channels, and which
    submissions have already been announced.

    Every write runs inside ``with self.db:`` so it is committed when the
    statements succeed and rolled back if one of them raises.
    """

    def __init__(self, path: str):
        """Open (or create) the database at *path* and ensure both tables exist."""
        # check_same_thread=False: discord.py runs the loop in one thread but
        # the connection may be touched from task callbacks on the same loop.
        self.db = sqlite3.connect(path, check_same_thread=False)
        with self.db:
            self.db.execute(
                """
                CREATE TABLE IF NOT EXISTS watches (
                    subreddit    TEXT    NOT NULL,
                    guild_id     INTEGER NOT NULL,
                    channel_id   INTEGER NOT NULL,
                    added_by     INTEGER NOT NULL,
                    added_at     INTEGER NOT NULL,
                    bootstrapped INTEGER NOT NULL DEFAULT 0,
                    PRIMARY KEY (subreddit, guild_id)
                )
                """
            )
            self.db.execute(
                """
                CREATE TABLE IF NOT EXISTS seen_posts (
                    post_id   TEXT    NOT NULL,
                    subreddit TEXT    NOT NULL,
                    seen_at   INTEGER NOT NULL,
                    PRIMARY KEY (post_id, subreddit)
                )
                """
            )

    # --- watches ----------------------------------------------------------

    def add_watch(self, subreddit: str, guild_id: int, channel_id: int, added_by: int):
        """Insert or re-point a watch. Re-watching resets the bootstrap flag so
        the new channel gets a clean silent baseline rather than a backlog."""
        with self.db:
            self.db.execute(
                """
                INSERT INTO watches (subreddit, guild_id, channel_id, added_by, added_at, bootstrapped)
                VALUES (?, ?, ?, ?, ?, 0)
                ON CONFLICT(subreddit, guild_id) DO UPDATE SET
                    channel_id = excluded.channel_id,
                    added_by = excluded.added_by,
                    added_at = excluded.added_at,
                    bootstrapped = 0
                """,
                (subreddit, guild_id, channel_id, added_by, int(time.time())),
            )

    def remove_watch(self, subreddit: str, guild_id: int) -> bool:
        """Delete one guild's watch on *subreddit*.

        When that was the last watch on the subreddit, its `seen_posts` rows
        are deleted as well: nothing would ever prune them again, and a later
        re-watch starts with bootstrapped = 0 and rebuilds the baseline anyway.

        Returns True if a watch row was deleted, False if none existed.
        """
        with self.db:
            cur = self.db.execute(
                "DELETE FROM watches WHERE subreddit = ? AND guild_id = ?",
                (subreddit, guild_id),
            )
            removed = cur.rowcount > 0
            if removed:
                self.db.execute(
                    """
                    DELETE FROM seen_posts
                    WHERE subreddit = ?
                      AND NOT EXISTS (SELECT 1 FROM watches WHERE subreddit = ?)
                    """,
                    (subreddit, subreddit),
                )
        return removed

    def list_watches(self, guild_id: int):
        """Return ``(subreddit, channel_id)`` rows for one guild, sorted by subreddit."""
        return self.db.execute(
            "SELECT subreddit, channel_id FROM watches WHERE guild_id = ? ORDER BY subreddit",
            (guild_id,),
        ).fetchall()

    def all_watches(self):
        """Every watch across all guilds, for the poll loop.

        Rows are ``(subreddit, guild_id, channel_id, bootstrapped)``.
        """
        return self.db.execute(
            "SELECT subreddit, guild_id, channel_id, bootstrapped FROM watches"
        ).fetchall()

    def distinct_subreddits(self):
        """Return the name of every subreddit that has at least one watch."""
        rows = self.db.execute("SELECT DISTINCT subreddit FROM watches").fetchall()
        return [row[0] for row in rows]

    def mark_bootstrapped(self, subreddit: str):
        """Set the bootstrap flag on every watch row of *subreddit* (all guilds)."""
        with self.db:
            self.db.execute(
                "UPDATE watches SET bootstrapped = 1 WHERE subreddit = ?", (subreddit,)
            )

    # --- seen posts -------------------------------------------------------

    def is_seen(self, post_id: str, subreddit: str) -> bool:
        """Return True if *post_id* has already been recorded for *subreddit*."""
        row = self.db.execute(
            "SELECT 1 FROM seen_posts WHERE post_id = ? AND subreddit = ?",
            (post_id, subreddit),
        ).fetchone()
        return row is not None

    def mark_seen(self, post_id: str, subreddit: str):
        """Record *post_id* as handled; a repeat call for the same id is a no-op."""
        with self.db:
            self.db.execute(
                "INSERT OR IGNORE INTO seen_posts (post_id, subreddit, seen_at) VALUES (?, ?, ?)",
                (post_id, subreddit, int(time.time())),
            )

    def prune_seen(self, subreddit: str, keep: int = 500):
        """Keep the seen table from growing unbounded: retain the most recent
        `keep` ids per subreddit. Old posts won't reappear on /new anyway.

        `keep` must stay above the number of posts fetched per sweep, otherwise
        ids that are still on /new get deleted and announced again.
        """
        with self.db:
            self.db.execute(
                """
                DELETE FROM seen_posts
                WHERE subreddit = ? AND post_id NOT IN (
                    SELECT post_id FROM seen_posts
                    WHERE subreddit = ? ORDER BY seen_at DESC LIMIT ?
                )
                """,
                (subreddit, subreddit, keep),
            )
# --------------------------------------------------------------------------
# Discord client
# --------------------------------------------------------------------------
class AnnouncerBot(discord.Client):
    """Discord client that polls watched subreddits and announces new posts.

    Holds the slash-command tree, the asyncpraw Reddit client (created in
    `setup_hook`) and the background poll loop. Persistent state is delegated
    to the `Store` passed to the constructor.
    """

    # URL suffixes that are shown as the embed image instead of a link field.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp")
    # Discord embed title hard limit.
    _TITLE_LIMIT = 256
    # Length cap applied to the self-text preview.
    _PREVIEW_LIMIT = 500
    # Discord embed field value hard limit.
    _FIELD_VALUE_LIMIT = 1024
    # Floor for the number of seen ids retained per subreddit when pruning.
    _MIN_SEEN_KEEP = 500

    def __init__(self, store: Store):
        """Create the client around *store*; the Reddit client is built later."""
        # No privileged intents needed: we only post, never read messages.
        intents = discord.Intents.default()
        super().__init__(intents=intents)
        self.store = store
        self.tree = app_commands.CommandTree(self)
        self.reddit: asyncpraw.Reddit | None = None
        self.poll_interval = POLL_INTERVAL

    async def setup_hook(self):
        """Create the read-only Reddit client, sync commands, start polling."""
        self.reddit = asyncpraw.Reddit(
            client_id=REDDIT_CLIENT_ID,
            client_secret=REDDIT_CLIENT_SECRET,
            user_agent=REDDIT_USER_AGENT,
        )
        self.reddit.read_only = True
        await self.tree.sync()
        self.poll_loop.change_interval(seconds=self.poll_interval)
        self.poll_loop.start()

    async def on_ready(self):
        """Log the bot identity and how many subreddits/guilds are in play."""
        log.info("Logged in as %s (id %s)", self.user, self.user.id)
        log.info(
            "Watching %d subreddit(s) across %d guild(s)",
            len(self.store.distinct_subreddits()),
            len(self.guilds),
        )

    async def close(self):
        """Close the Reddit client (if one was created), then the Discord client."""
        if self.reddit:
            await self.reddit.close()
        await super().close()

    # --- the poll loop ----------------------------------------------------

    @tasks.loop(seconds=POLL_INTERVAL)
    async def poll_loop(self):
        """Run one sweep: poll every watched subreddit once."""
        # Group watches by subreddit so each is fetched once per sweep even if
        # several channels (or guilds) want it.
        by_sub: dict[str, list[tuple[int, int, int]]] = {}
        for subreddit, guild_id, channel_id, bootstrapped in self.store.all_watches():
            by_sub.setdefault(subreddit, []).append(
                (guild_id, channel_id, bootstrapped)
            )
        for subreddit, targets in by_sub.items():
            try:
                await self._poll_subreddit(subreddit, targets)
            except Exception:
                # One bad subreddit (deleted, private, banned) must not kill
                # the whole sweep — log and move on.
                log.exception("Error polling r/%s", subreddit)

    async def _poll_subreddit(self, subreddit: str, targets: list[tuple[int, int, int]]):
        """Fetch /new for *subreddit* and announce unseen posts.

        *targets* holds ``(guild_id, channel_id, bootstrapped)`` for every
        watch on the subreddit. Only targets whose flag is already set receive
        announcements. Targets added since the previous sweep are baselined
        silently: every fetched post ends up marked seen and their flag is set
        afterwards, so they never receive a backlog. Established targets keep
        receiving posts even while another guild is being bootstrapped.
        """
        sub = await self.reddit.subreddit(subreddit)
        # Newest first from Reddit; reverse so we announce oldest->newest.
        posts = [s async for s in sub.new(limit=FETCH_LIMIT)]
        posts.reverse()

        live_targets = [
            (guild_id, channel_id)
            for guild_id, channel_id, bootstrapped in targets
            if bootstrapped
        ]
        pending = len(targets) - len(live_targets)

        if not live_targets:
            # Silent baseline: mark everything currently on /new as seen and
            # don't announce, so a freshly-added subreddit doesn't dump its
            # backlog. Flip the flag for every target row of this subreddit.
            for post in posts:
                self.store.mark_seen(post.id, subreddit)
            self.store.mark_bootstrapped(subreddit)
            log.info(
                "Bootstrapped r/%s (%d existing posts marked seen)",
                subreddit,
                len(posts),
            )
            return

        for post in posts:
            if self.store.is_seen(post.id, subreddit):
                continue
            # Marked before sending so a failing channel cannot cause the same
            # post to be retried on every sweep.
            self.store.mark_seen(post.id, subreddit)
            embed = self._build_embed(subreddit, post)
            await self._send_to_targets(subreddit, post.id, embed, live_targets)

        if pending:
            # Everything fetched is now marked seen, which is exactly the
            # baseline the newly added targets need.
            self.store.mark_bootstrapped(subreddit)
            log.info(
                "Bootstrapped %d new target(s) for r/%s", pending, subreddit
            )

        # Retain comfortably more ids than one fetch returns; pruning below
        # FETCH_LIMIT would delete ids still on /new and re-announce them.
        self.store.prune_seen(
            subreddit, keep=max(self._MIN_SEEN_KEEP, 2 * FETCH_LIMIT)
        )

    async def _send_to_targets(
        self,
        subreddit: str,
        post_id: str,
        embed: discord.Embed,
        targets: list[tuple[int, int]],
    ):
        """Send *embed* to every ``(guild_id, channel_id)`` target, logging failures."""
        for guild_id, channel_id in targets:
            channel = self.get_channel(channel_id)
            if channel is None:
                log.warning(
                    "Channel %s (guild %s) for r/%s not found; skipping",
                    channel_id,
                    guild_id,
                    subreddit,
                )
                continue
            try:
                await channel.send(embed=embed)
            except discord.DiscordException:
                log.exception(
                    "Failed to post r/%s/%s to channel %s",
                    subreddit,
                    post_id,
                    channel_id,
                )

    @poll_loop.before_loop
    async def _before_poll(self):
        """Hold the first sweep until the client reports ready."""
        await self.wait_until_ready()

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Return *text* cut to at most *limit* characters, ending in "..." if cut."""
        if len(text) <= limit:
            return text
        return text[: limit - 3] + "..."

    @staticmethod
    def _build_embed(subreddit: str, post) -> discord.Embed:
        """Build the announcement embed for *post*.

        The title links to the Reddit permalink. Text posts get a truncated
        self-text preview; other posts get either an inline image (direct image
        URLs) or a "Link" field.
        """
        permalink = f"https://reddit.com{post.permalink}"
        embed = discord.Embed(
            title=AnnouncerBot._truncate(post.title, AnnouncerBot._TITLE_LIMIT),
            url=permalink,
            color=0xFF4500,  # reddit orange
        )
        embed.set_author(name=f"r/{subreddit}", url=f"https://reddit.com/r/{subreddit}")

        author = getattr(post, "author", None)
        embed.add_field(
            name="Author", value=f"u/{author}" if author else "[deleted]", inline=True
        )

        # Self-text preview for text posts; thumbnail/link for the rest.
        if getattr(post, "is_self", False) and post.selftext:
            embed.description = AnnouncerBot._truncate(
                post.selftext, AnnouncerBot._PREVIEW_LIMIT
            )
        elif getattr(post, "url", None):
            url = post.url
            if url.lower().endswith(AnnouncerBot._IMAGE_EXTENSIONS):
                embed.set_image(url=url)
            elif len(url) <= AnnouncerBot._FIELD_VALUE_LIMIT:
                embed.add_field(name="Link", value=url, inline=False)
            # An over-long link is left out rather than truncated into a broken
            # URL; Discord would reject the whole message, and the title still
            # links to the post.

        embed.set_footer(text="Reddit")
        created = getattr(post, "created_utc", None)
        if created:
            embed.timestamp = datetime.datetime.fromtimestamp(
                created, tz=datetime.timezone.utc
            )
        return embed
# --------------------------------------------------------------------------
# Slash commands
# --------------------------------------------------------------------------
def _clean_subreddit(name: str) -> str:
    """Normalise user input to a bare, lower-case subreddit name.

    Accepts a bare name, an ``r/`` or ``/r/`` prefixed name, a trailing slash,
    or a pasted URL/path containing ``/r/<name>``:

        'r/Python', '/r/python ', 'Python', 'r/python/',
        'https://www.reddit.com/r/Python/comments/abc'  ->  'python'

    Returns "" when nothing usable is left, which callers treat as invalid.
    """
    cleaned = name.strip().lower()
    _, marker, after = cleaned.partition("/r/")
    if marker:
        # URL or absolute path: the name is the first segment after "/r/".
        cleaned = after.strip().split("/", 1)[0]
    else:
        cleaned = cleaned.lstrip("/")
        if cleaned.startswith("r/"):
            cleaned = cleaned[2:].strip().split("/", 1)[0]
    # Drop any stray slashes/whitespace left on a bare name such as "python/".
    return cleaned.strip("/").strip()
def register_commands(bot: AnnouncerBot):
    """Attach the /watch, /unwatch, /watching and /setinterval commands to
    *bot*'s command tree. All replies are ephemeral."""
    tree = bot.tree

    # Bounds for /setinterval. The upper bound keeps an absurd value from
    # overflowing the task loop's interval arithmetic or silently disabling
    # polling.
    min_interval = 15
    max_interval = 24 * 60 * 60

    async def _reject_dm(interaction: discord.Interaction) -> bool:
        """Reply with the server-only notice and return True when used in a DM."""
        if interaction.guild_id is not None:
            return False
        await interaction.response.send_message(
            "Use this in a server, not a DM.", ephemeral=True
        )
        return True

    @tree.command(name="watch", description="Announce a subreddit's new posts in a channel.")
    @app_commands.describe(
        subreddit="Subreddit name, e.g. python or r/python",
        channel="Channel to post in (defaults to the current channel)",
    )
    async def watch(
        interaction: discord.Interaction,
        subreddit: str,
        channel: discord.TextChannel | None = None,
    ):
        """Validate the subreddit on Reddit, then store the watch for this guild."""
        if await _reject_dm(interaction):
            return
        name = _clean_subreddit(subreddit)
        if not name:
            await interaction.response.send_message(
                "That doesn't look like a subreddit name.", ephemeral=True
            )
            return
        target = channel or interaction.channel
        # Verify the subreddit exists and is readable before committing.
        # Deferred first because the Reddit round-trip can be slow.
        await interaction.response.defer(ephemeral=True)
        try:
            sub = await bot.reddit.subreddit(name)
            await sub.load()  # raises if it doesn't exist / is private / banned
        except Exception as exc:
            log.info("Rejected r/%s: %s", name, exc)
            await interaction.followup.send(
                f"Couldn't access **r/{name}** — check the name (and that it's public).",
                ephemeral=True,
            )
            return
        bot.store.add_watch(name, interaction.guild_id, target.id, interaction.user.id)
        await interaction.followup.send(
            f"Now watching **r/{name}** → {target.mention}. "
            f"Existing posts are skipped; you'll get new ones from here on.",
            ephemeral=True,
        )

    @tree.command(name="unwatch", description="Stop announcing a subreddit in this server.")
    @app_commands.describe(subreddit="Subreddit to stop watching")
    async def unwatch(interaction: discord.Interaction, subreddit: str):
        """Remove this guild's watch on the subreddit, if there is one."""
        if await _reject_dm(interaction):
            return
        name = _clean_subreddit(subreddit)
        removed = bot.store.remove_watch(name, interaction.guild_id)
        if removed:
            await interaction.response.send_message(
                f"Stopped watching **r/{name}**.", ephemeral=True
            )
        else:
            await interaction.response.send_message(
                f"Wasn't watching **r/{name}** here.", ephemeral=True
            )

    @tree.command(name="watching", description="List the subreddits announced in this server.")
    async def watching(interaction: discord.Interaction):
        """List this guild's watches as ``r/<sub> → #channel`` lines."""
        if await _reject_dm(interaction):
            return
        rows = bot.store.list_watches(interaction.guild_id)
        if not rows:
            await interaction.response.send_message(
                "Not watching any subreddits here yet. Add one with `/watch`.",
                ephemeral=True,
            )
            return
        lines = [f"• **r/{sub}** → <#{chan}>" for sub, chan in rows]
        await interaction.response.send_message("\n".join(lines), ephemeral=True)

    @tree.command(name="setinterval", description="Set how often (seconds) new posts are checked.")
    @app_commands.describe(seconds="Poll interval in seconds (minimum 15)")
    # The interval is shared by every guild, so by default only members who can
    # manage the server are offered this command. NOTE(review): this is still a
    # per-guild gate on a process-wide setting; restrict further if the bot is
    # installed in servers you don't control.
    @app_commands.default_permissions(manage_guild=True)
    async def setinterval(interaction: discord.Interaction, seconds: int):
        """Change the global poll interval after validating its bounds."""
        # Default permissions do not apply in DMs, so guard explicitly, the
        # same way the other commands do.
        if await _reject_dm(interaction):
            return
        if seconds < min_interval:
            await interaction.response.send_message(
                "Minimum interval is 15 seconds (Reddit rate limits).", ephemeral=True
            )
            return
        if seconds > max_interval:
            await interaction.response.send_message(
                f"Maximum interval is {max_interval} seconds (one day).", ephemeral=True
            )
            return
        bot.poll_interval = float(seconds)
        bot.poll_loop.change_interval(seconds=float(seconds))
        await interaction.response.send_message(
            f"Poll interval set to {seconds}s.", ephemeral=True
        )
# --------------------------------------------------------------------------
# Entry point
# --------------------------------------------------------------------------
def main():
    """Check required configuration, wire up the store and bot, and run it.

    Exits via SystemExit with an explanatory message when the Discord token or
    the Reddit client credentials are missing.
    """
    if not DISCORD_TOKEN:
        raise SystemExit("DISCORD_TOKEN is not set — see config.example.env")

    reddit_creds_missing = not REDDIT_CLIENT_ID or not REDDIT_CLIENT_SECRET
    if reddit_creds_missing:
        raise SystemExit(
            "REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET not set — see config.example.env"
        )

    state = Store(DB_PATH)
    log.info("State DB: %s", DB_PATH)

    client = AnnouncerBot(state)
    register_commands(client)
    # log_handler=None: logging is already configured at module import, so
    # discord.py must not install its own handler.
    client.run(DISCORD_TOKEN, log_handler=None)


if __name__ == "__main__":
    main()