475 lines
18 KiB
Python
475 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
"""subreddit-announcer — a Discord bot that announces new Reddit posts.
|
|
|
|
A long-running discord.py bot. It polls the `/new` listing of each watched
|
|
subreddit on a fixed interval and posts an embed to the mapped Discord channel
|
|
for every submission it hasn't seen before.
|
|
|
|
State lives in a local SQLite DB:
|
|
- `watches` : subreddit -> (guild, channel) mapping + bootstrap flag
|
|
- `seen_posts` : every submission id we've already handled, for dedup
|
|
|
|
When a subreddit is first watched it is *bootstrapped silently* — the posts
|
|
currently on /new are marked seen without announcing, so adding a busy
|
|
subreddit doesn't dump its backlog into the channel. Only posts created after
|
|
that point are announced.
|
|
|
|
Subreddits are managed at runtime with slash commands:
|
|
/watch <subreddit> [channel] start announcing a subreddit's new posts
|
|
/unwatch <subreddit> stop announcing it
|
|
/watching list active watches in this server
|
|
/setinterval <seconds> change the global poll interval
|
|
|
|
Configuration comes from environment variables (see config.example.env).
|
|
Reddit access is read-only via a registered app's client id/secret — no
|
|
Reddit account login is required.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import datetime
|
|
import logging
|
|
import os
|
|
import sqlite3
|
|
import time
|
|
|
|
import asyncpraw
|
|
import discord
|
|
from discord import app_commands
|
|
from discord.ext import tasks
|
|
|
|
# --------------------------------------------------------------------------
|
|
# Configuration (env-overridable; secrets stay in config.env, never in git)
|
|
# --------------------------------------------------------------------------
|
|
|
|
DISCORD_TOKEN = os.environ.get("DISCORD_TOKEN", "")
|
|
|
|
REDDIT_CLIENT_ID = os.environ.get("REDDIT_CLIENT_ID", "")
|
|
REDDIT_CLIENT_SECRET = os.environ.get("REDDIT_CLIENT_SECRET", "")
|
|
REDDIT_USER_AGENT = os.environ.get(
|
|
"REDDIT_USER_AGENT", "discord:subreddit-announcer:1.0 (by /u/megaproxy)"
|
|
)
|
|
|
|
DB_PATH = os.environ.get(
|
|
"ANNOUNCER_DB", os.path.join(os.path.dirname(__file__), "announcer.db")
|
|
)
|
|
|
|
# Seconds between poll sweeps across all watched subreddits.
|
|
POLL_INTERVAL = float(os.environ.get("ANNOUNCER_POLL_INTERVAL", "60"))
|
|
# How many posts to pull from each subreddit's /new per sweep. 25 is plenty
|
|
# unless a sub gets >25 posts within one POLL_INTERVAL.
|
|
FETCH_LIMIT = int(os.environ.get("ANNOUNCER_FETCH_LIMIT", "25"))
|
|
|
|
logging.basicConfig(
|
|
level=os.environ.get("ANNOUNCER_LOGLEVEL", "INFO"),
|
|
format="%(asctime)s %(levelname)s %(message)s",
|
|
)
|
|
log = logging.getLogger("announcer")
|
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
# State storage
|
|
# --------------------------------------------------------------------------
|
|
|
|
|
|
class Store:
    """SQLite-backed state: which subreddits go to which channels, and which
    submissions have already been announced."""

    def __init__(self, path: str):
        # check_same_thread=False only relaxes sqlite3's "same thread that
        # opened it" guard. It does not make concurrent use safe; all access in
        # this bot is expected to come from the event-loop thread.
        self.db = sqlite3.connect(path, check_same_thread=False)
        self.db.execute(
            """
            CREATE TABLE IF NOT EXISTS watches (
                subreddit    TEXT    NOT NULL,
                guild_id     INTEGER NOT NULL,
                channel_id   INTEGER NOT NULL,
                added_by     INTEGER NOT NULL,
                added_at     INTEGER NOT NULL,
                bootstrapped INTEGER NOT NULL DEFAULT 0,
                PRIMARY KEY (subreddit, guild_id)
            )
            """
        )
        self.db.execute(
            """
            CREATE TABLE IF NOT EXISTS seen_posts (
                post_id   TEXT    NOT NULL,
                subreddit TEXT    NOT NULL,
                seen_at   INTEGER NOT NULL,
                PRIMARY KEY (post_id, subreddit)
            )
            """
        )
        # prune_seen() filters by subreddit and orders by seen_at on every
        # sweep. The primary key leads with post_id and cannot serve that
        # query, so give it a matching index.
        self.db.execute(
            "CREATE INDEX IF NOT EXISTS seen_posts_by_sub_time "
            "ON seen_posts (subreddit, seen_at)"
        )
        self.db.commit()

    # --- watches ----------------------------------------------------------

    def add_watch(self, subreddit: str, guild_id: int, channel_id: int, added_by: int):
        """Insert or re-point a watch. Re-watching resets the bootstrap flag so
        the new channel gets a clean silent baseline rather than a backlog."""
        self.db.execute(
            """
            INSERT INTO watches (subreddit, guild_id, channel_id, added_by, added_at, bootstrapped)
            VALUES (?, ?, ?, ?, ?, 0)
            ON CONFLICT(subreddit, guild_id) DO UPDATE SET
                channel_id   = excluded.channel_id,
                added_by     = excluded.added_by,
                added_at     = excluded.added_at,
                bootstrapped = 0
            """,
            (subreddit, guild_id, channel_id, added_by, int(time.time())),
        )
        self.db.commit()

    def remove_watch(self, subreddit: str, guild_id: int) -> bool:
        """Delete the (subreddit, guild) watch; return True if one existed.

        If that was the last watch on the subreddit, its seen-post history is
        dropped as well. prune_seen() is only ever called for subreddits that
        are still polled, so these rows would otherwise stay forever. Dropping
        them is safe because a new watch starts with bootstrapped = 0 and gets
        a fresh silent baseline.
        """
        cur = self.db.execute(
            "DELETE FROM watches WHERE subreddit = ? AND guild_id = ?",
            (subreddit, guild_id),
        )
        removed = cur.rowcount > 0
        if removed:
            still_watched = self.db.execute(
                "SELECT 1 FROM watches WHERE subreddit = ? LIMIT 1", (subreddit,)
            ).fetchone()
            if still_watched is None:
                self.db.execute(
                    "DELETE FROM seen_posts WHERE subreddit = ?", (subreddit,)
                )
        self.db.commit()
        return removed

    def list_watches(self, guild_id: int):
        """(subreddit, channel_id) rows for one guild, sorted by subreddit."""
        return self.db.execute(
            "SELECT subreddit, channel_id FROM watches WHERE guild_id = ? ORDER BY subreddit",
            (guild_id,),
        ).fetchall()

    def all_watches(self):
        """Every watch across all guilds, for the poll loop."""
        return self.db.execute(
            "SELECT subreddit, guild_id, channel_id, bootstrapped FROM watches"
        ).fetchall()

    def distinct_subreddits(self):
        """Names of all subreddits watched by at least one guild."""
        return [
            r[0]
            for r in self.db.execute(
                "SELECT DISTINCT subreddit FROM watches"
            ).fetchall()
        ]

    def mark_bootstrapped(self, subreddit: str):
        """Set bootstrapped = 1 on every watch row for this subreddit."""
        self.db.execute(
            "UPDATE watches SET bootstrapped = 1 WHERE subreddit = ?", (subreddit,)
        )
        self.db.commit()

    # --- seen posts -------------------------------------------------------

    def is_seen(self, post_id: str, subreddit: str) -> bool:
        """True if this submission id was already recorded for the subreddit."""
        return (
            self.db.execute(
                "SELECT 1 FROM seen_posts WHERE post_id = ? AND subreddit = ?",
                (post_id, subreddit),
            ).fetchone()
            is not None
        )

    def mark_seen(self, post_id: str, subreddit: str):
        """Record a submission id as handled. Recording it twice is a no-op."""
        self.db.execute(
            "INSERT OR IGNORE INTO seen_posts (post_id, subreddit, seen_at) VALUES (?, ?, ?)",
            (post_id, subreddit, int(time.time())),
        )
        self.db.commit()

    def prune_seen(self, subreddit: str, keep: int = 500):
        """Keep the seen table from growing unbounded: retain the most recent
        `keep` ids per subreddit. Old posts won't reappear on /new anyway."""
        self.db.execute(
            """
            DELETE FROM seen_posts
            WHERE subreddit = ? AND post_id NOT IN (
                SELECT post_id FROM seen_posts
                WHERE subreddit = ? ORDER BY seen_at DESC LIMIT ?
            )
            """,
            (subreddit, subreddit, keep),
        )
        self.db.commit()
|
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
# Discord client
|
|
# --------------------------------------------------------------------------
|
|
|
|
|
|
class AnnouncerBot(discord.Client):
    """Discord client that polls watched subreddits and announces new posts.

    Holds the shared Store, the app-command tree and, once setup_hook has run,
    a read-only asyncpraw Reddit client.
    """

    def __init__(self, store: Store):
        # No privileged intents needed: we only post, never read messages.
        intents = discord.Intents.default()
        super().__init__(intents=intents)
        self.store = store
        self.tree = app_commands.CommandTree(self)
        self.reddit: asyncpraw.Reddit | None = None
        self.poll_interval = POLL_INTERVAL

    async def setup_hook(self):
        """Create the Reddit client, sync slash commands and start polling."""
        self.reddit = asyncpraw.Reddit(
            client_id=REDDIT_CLIENT_ID,
            client_secret=REDDIT_CLIENT_SECRET,
            user_agent=REDDIT_USER_AGENT,
        )
        self.reddit.read_only = True
        await self.tree.sync()
        self.poll_loop.change_interval(seconds=self.poll_interval)
        self.poll_loop.start()

    async def on_ready(self):
        log.info("Logged in as %s (id %s)", self.user, self.user.id)
        log.info(
            "Watching %d subreddit(s) across %d guild(s)",
            len(self.store.distinct_subreddits()),
            len(self.guilds),
        )

    async def close(self):
        """Close the Reddit session, then shut down the Discord client."""
        if self.reddit:
            await self.reddit.close()
        await super().close()

    # --- the poll loop ----------------------------------------------------

    @tasks.loop(seconds=POLL_INTERVAL)
    async def poll_loop(self):
        """One sweep: fetch each watched subreddit once, fan out to its targets."""
        # An exception escaping a tasks.loop iteration ends the loop, so a DB
        # hiccup skips this sweep instead of silently stopping all polling.
        try:
            watches = self.store.all_watches()
        except sqlite3.Error:
            log.exception("Could not read watches; skipping this sweep")
            return

        # Group watches by subreddit so each is fetched once per sweep even if
        # several channels (or guilds) want it.
        by_sub: dict[str, list[tuple[int, int, int]]] = {}
        for subreddit, guild_id, channel_id, bootstrapped in watches:
            by_sub.setdefault(subreddit, []).append(
                (guild_id, channel_id, bootstrapped)
            )

        for subreddit, targets in by_sub.items():
            try:
                await self._poll_subreddit(subreddit, targets)
            except Exception:
                # One bad subreddit (deleted, private, banned) must not kill
                # the whole sweep — log and move on.
                log.exception("Error polling r/%s", subreddit)

    async def _poll_subreddit(self, subreddit: str, targets: list[tuple[int, int, int]]):
        """Fetch r/<subreddit>/new once and announce unseen posts.

        ``targets`` holds (guild_id, channel_id, bootstrapped) rows for this
        subreddit. Rows with bootstrapped == 0 were just added (or re-pointed)
        and get a silent baseline instead of announcements:

        * If no target is bootstrapped yet, every post currently on /new is
          marked seen and nothing is announced.
        * Otherwise, unseen posts go to the already-bootstrapped targets only.
          The new targets are then flagged, so they receive posts from the next
          sweep onward. Existing targets therefore never lose posts just
          because another guild started watching the same subreddit.
        """
        sub = await self.reddit.subreddit(subreddit)
        # Newest first from Reddit; reverse so we announce oldest->newest.
        posts = [s async for s in sub.new(limit=FETCH_LIMIT)]
        posts.reverse()

        live_targets = [
            (guild_id, channel_id)
            for guild_id, channel_id, bootstrapped in targets
            if bootstrapped
        ]
        needs_baseline = len(live_targets) < len(targets)

        if not live_targets:
            # Silent baseline: mark everything currently on /new as seen and
            # don't announce, so a freshly-added subreddit doesn't dump its
            # backlog. Flip the flag for every target row of this subreddit.
            for post in posts:
                self.store.mark_seen(post.id, subreddit)
            self.store.mark_bootstrapped(subreddit)
            log.info(
                "Bootstrapped r/%s (%d existing posts marked seen)",
                subreddit,
                len(posts),
            )
            return

        for post in posts:
            if self.store.is_seen(post.id, subreddit):
                continue
            # Marked before sending: a failed send drops the post rather than
            # retrying (and possibly duplicating) it every sweep.
            self.store.mark_seen(post.id, subreddit)
            embed = self._build_embed(subreddit, post)
            for guild_id, channel_id in live_targets:
                channel = self.get_channel(channel_id)
                if channel is None:
                    log.warning(
                        "Channel %s (guild %s) for r/%s not found; skipping",
                        channel_id,
                        guild_id,
                        subreddit,
                    )
                    continue
                try:
                    await channel.send(embed=embed)
                except discord.DiscordException:
                    log.exception(
                        "Failed to post r/%s/%s to channel %s",
                        subreddit,
                        post.id,
                        channel_id,
                    )

        if needs_baseline:
            # Everything currently on /new is now marked seen, so the newly
            # added targets start from this point with no backlog.
            self.store.mark_bootstrapped(subreddit)
            log.info("Bootstrapped new target(s) for already-watched r/%s", subreddit)

        # The retained history must cover every post that can still appear on
        # /new. If an id still listed there were pruned, it would look unseen
        # next sweep and be announced again.
        self.store.prune_seen(subreddit, keep=max(500, 2 * FETCH_LIMIT))

    @poll_loop.before_loop
    async def _before_poll(self):
        await self.wait_until_ready()

    @staticmethod
    def _build_embed(subreddit: str, post) -> discord.Embed:
        """Build the announcement embed for one submission.

        The title links to the post's permalink. Text posts get a truncated
        self-text preview. Link posts whose URL ends in a known image extension
        show the image inline; other link posts get a "Link" field.
        """
        title = post.title
        if len(title) > 256:  # Discord embed title hard limit
            title = title[:253] + "…"
        permalink = f"https://reddit.com{post.permalink}"
        embed = discord.Embed(
            title=title,
            url=permalink,
            color=0xFF4500,  # reddit orange
        )
        embed.set_author(name=f"r/{subreddit}", url=f"https://reddit.com/r/{subreddit}")
        author = getattr(post, "author", None)
        embed.add_field(
            name="Author", value=f"u/{author}" if author else "[deleted]", inline=True
        )
        # Self-text preview for text posts; thumbnail/link for the rest.
        if getattr(post, "is_self", False) and post.selftext:
            body = post.selftext
            if len(body) > 500:
                body = body[:497] + "…"
            embed.description = body
        elif getattr(post, "url", None):
            url = post.url
            if url.lower().endswith((".jpg", ".jpeg", ".png", ".gif", ".webp")):
                embed.set_image(url=url)
            elif len(url) <= 1024:
                # Discord rejects embeds whose field values exceed 1024
                # characters, which would fail the whole send. The title
                # already links to the post, so an over-long URL is just
                # omitted.
                embed.add_field(name="Link", value=url, inline=False)
        embed.set_footer(text="Reddit")
        created = getattr(post, "created_utc", None)
        if created:
            embed.timestamp = datetime.datetime.fromtimestamp(
                created, tz=datetime.timezone.utc
            )
        return embed
|
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
# Slash commands
|
|
# --------------------------------------------------------------------------
|
|
|
|
|
|
def _clean_subreddit(name: str) -> str:
|
|
"""Normalise user input like 'r/Python', '/r/python ', 'Python' -> 'python'."""
|
|
name = name.strip().lstrip("/")
|
|
if name.lower().startswith("r/"):
|
|
name = name[2:]
|
|
return name.strip().lower()
|
|
|
|
|
|
def register_commands(bot: AnnouncerBot):
    """Attach /watch, /unwatch, /watching and /setinterval to ``bot.tree``.

    The state-changing commands are registered with
    ``default_permissions(manage_guild=True)``, so by default only members
    with Manage Server can use them. /setinterval changes the poll interval
    for every server the bot is in, so it is additionally limited at runtime
    to the bot application's owner (or its team members).
    """
    tree = bot.tree

    # Upper bound for /setinterval; very large values would overflow the
    # loop's next-run datetime arithmetic.
    max_interval = 86_400

    # Owner ids, filled lazily on first use; ownership rarely changes, and
    # caching keeps /setinterval inside the interaction response deadline.
    owner_ids: set[int] = set()

    async def _is_app_owner(user_id: int) -> bool:
        """True if ``user_id`` owns this bot application or is on its team."""
        if not owner_ids:
            info = await bot.application_info()
            if info.team is not None:
                owner_ids.update(member.id for member in info.team.members)
            else:
                owner_ids.add(info.owner.id)
        return user_id in owner_ids

    @tree.command(name="watch", description="Announce a subreddit's new posts in a channel.")
    @app_commands.default_permissions(manage_guild=True)
    @app_commands.describe(
        subreddit="Subreddit name, e.g. python or r/python",
        channel="Channel to post in (defaults to the current channel)",
    )
    async def watch(
        interaction: discord.Interaction,
        subreddit: str,
        channel: discord.TextChannel | None = None,
    ):
        if interaction.guild_id is None:
            await interaction.response.send_message(
                "Use this in a server, not a DM.", ephemeral=True
            )
            return
        name = _clean_subreddit(subreddit)
        if not name:
            await interaction.response.send_message(
                "That doesn't look like a subreddit name.", ephemeral=True
            )
            return

        target = channel or interaction.channel
        # Refuse up front if the bot could not post announcements there;
        # otherwise every new post would only produce a logged send failure.
        me = interaction.guild.me if interaction.guild is not None else None
        if me is not None and isinstance(target, discord.TextChannel):
            perms = target.permissions_for(me)
            if not (perms.send_messages and perms.embed_links):
                await interaction.response.send_message(
                    f"I can't post announcements in {target.mention} — "
                    f"I need Send Messages and Embed Links there.",
                    ephemeral=True,
                )
                return

        # Verify the subreddit exists and is readable before committing.
        await interaction.response.defer(ephemeral=True)
        try:
            sub = await bot.reddit.subreddit(name)
            await sub.load()  # raises if it doesn't exist / is private / banned
        except Exception as exc:
            log.info("Rejected r/%s: %s", name, exc)
            await interaction.followup.send(
                f"Couldn't access **r/{name}** — check the name (and that it's public).",
                ephemeral=True,
            )
            return

        bot.store.add_watch(name, interaction.guild_id, target.id, interaction.user.id)
        await interaction.followup.send(
            f"Now watching **r/{name}** → {target.mention}. "
            f"Existing posts are skipped; you'll get new ones from here on.",
            ephemeral=True,
        )

    @tree.command(name="unwatch", description="Stop announcing a subreddit in this server.")
    @app_commands.default_permissions(manage_guild=True)
    @app_commands.describe(subreddit="Subreddit to stop watching")
    async def unwatch(interaction: discord.Interaction, subreddit: str):
        if interaction.guild_id is None:
            await interaction.response.send_message(
                "Use this in a server, not a DM.", ephemeral=True
            )
            return
        name = _clean_subreddit(subreddit)
        removed = bot.store.remove_watch(name, interaction.guild_id)
        if removed:
            await interaction.response.send_message(
                f"Stopped watching **r/{name}**.", ephemeral=True
            )
        else:
            await interaction.response.send_message(
                f"Wasn't watching **r/{name}** here.", ephemeral=True
            )

    @tree.command(name="watching", description="List the subreddits announced in this server.")
    async def watching(interaction: discord.Interaction):
        if interaction.guild_id is None:
            await interaction.response.send_message(
                "Use this in a server, not a DM.", ephemeral=True
            )
            return
        rows = bot.store.list_watches(interaction.guild_id)
        if not rows:
            await interaction.response.send_message(
                "Not watching any subreddits here yet. Add one with `/watch`.",
                ephemeral=True,
            )
            return
        lines = [f"• **r/{sub}** → <#{chan}>" for sub, chan in rows]
        text = "\n".join(lines)
        if len(text) > 2000:
            # Discord rejects messages longer than 2000 characters: show as
            # many rows as fit and summarise the remainder.
            shown: list[str] = []
            used = 0
            for line in lines:
                if used + len(line) + 1 > 1900:
                    break
                shown.append(line)
                used += len(line) + 1
            text = "\n".join(shown) + f"\n…and {len(lines) - len(shown)} more."
        await interaction.response.send_message(text, ephemeral=True)

    @tree.command(name="setinterval", description="Set how often (seconds) new posts are checked.")
    @app_commands.default_permissions(manage_guild=True)
    @app_commands.describe(seconds="Poll interval in seconds (15 to 86400)")
    async def setinterval(interaction: discord.Interaction, seconds: int):
        # The interval is global across every server, so a single guild's
        # admins must not be able to change it for everyone.
        if not await _is_app_owner(interaction.user.id):
            await interaction.response.send_message(
                "Only the bot's owner can change the poll interval — it applies to every server.",
                ephemeral=True,
            )
            return
        if seconds < 15:
            await interaction.response.send_message(
                "Minimum interval is 15 seconds (Reddit rate limits).", ephemeral=True
            )
            return
        if seconds > max_interval:
            await interaction.response.send_message(
                f"Maximum interval is {max_interval} seconds (one day).", ephemeral=True
            )
            return
        bot.poll_interval = float(seconds)
        bot.poll_loop.change_interval(seconds=float(seconds))
        await interaction.response.send_message(
            f"Poll interval set to {seconds}s.", ephemeral=True
        )
|
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
# Entry point
|
|
# --------------------------------------------------------------------------
|
|
|
|
|
|
def main():
    """Validate configuration, open the state DB and run the bot until it stops.

    Raises:
        SystemExit: if a required credential is missing, or if the poll
            interval or fetch limit is not positive.
    """
    if not DISCORD_TOKEN:
        raise SystemExit("DISCORD_TOKEN is not set — see config.example.env")
    if not (REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET):
        raise SystemExit(
            "REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET not set — see config.example.env"
        )
    if POLL_INTERVAL <= 0:
        raise SystemExit("ANNOUNCER_POLL_INTERVAL must be a positive number of seconds")
    if FETCH_LIMIT <= 0:
        raise SystemExit("ANNOUNCER_FETCH_LIMIT must be a positive integer")
    if POLL_INTERVAL < 15:
        # /setinterval refuses values below 15s for Reddit rate-limit reasons;
        # the env var bypasses that check, so at least make it visible.
        log.warning(
            "ANNOUNCER_POLL_INTERVAL=%s is below the 15s minimum /setinterval enforces",
            POLL_INTERVAL,
        )

    store = Store(DB_PATH)
    log.info("State DB: %s", DB_PATH)
    try:
        bot = AnnouncerBot(store)
        register_commands(bot)
        bot.run(DISCORD_TOKEN, log_handler=None)  # we configure logging ourselves
    finally:
        # bot.run() only returns (or raises) once the client has shut down;
        # release the SQLite connection deterministically at that point.
        store.db.close()


if __name__ == "__main__":
    main()
|