Initial scaffold: Discord subreddit-announcer bot
This commit is contained in:
commit
705c6ba9f7
8 changed files with 1715 additions and 0 deletions
475
bot.py
Normal file
475
bot.py
Normal file
|
|
@ -0,0 +1,475 @@
|
|||
#!/usr/bin/env python3
|
||||
"""subreddit-announcer — a Discord bot that announces new Reddit posts.
|
||||
|
||||
A long-running discord.py bot. It polls the `/new` listing of each watched
|
||||
subreddit on a fixed interval and posts an embed to the mapped Discord channel
|
||||
for every submission it hasn't seen before.
|
||||
|
||||
State lives in a local SQLite DB:
|
||||
- `watches` : subreddit -> (guild, channel) mapping + bootstrap flag
|
||||
- `seen_posts` : every submission id we've already handled, for dedup
|
||||
|
||||
When a subreddit is first watched it is *bootstrapped silently* — the posts
|
||||
currently on /new are marked seen without announcing, so adding a busy
|
||||
subreddit doesn't dump its backlog into the channel. Only posts created after
|
||||
that point are announced.
|
||||
|
||||
Subreddits are managed at runtime with slash commands:
|
||||
/watch <subreddit> [channel] start announcing a subreddit's new posts
|
||||
/unwatch <subreddit> stop announcing it
|
||||
/watching list active watches in this server
|
||||
/setinterval <seconds> change the global poll interval
|
||||
|
||||
Configuration comes from environment variables (see config.example.env).
|
||||
Reddit access is read-only via a registered app's client id/secret — no
|
||||
Reddit account login is required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
|
||||
import asyncpraw
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import tasks
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Configuration (env-overridable; secrets stay in config.env, never in git)
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
# Discord bot token. main() refuses to start while this is empty.
DISCORD_TOKEN = os.environ.get("DISCORD_TOKEN", "")

# Credentials of the registered Reddit app used for read-only API access.
REDDIT_CLIENT_ID = os.environ.get("REDDIT_CLIENT_ID", "")
REDDIT_CLIENT_SECRET = os.environ.get("REDDIT_CLIENT_SECRET", "")
REDDIT_USER_AGENT = os.environ.get(
    "REDDIT_USER_AGENT", "discord:subreddit-announcer:1.0 (by /u/megaproxy)"
)

# SQLite state file; defaults to announcer.db next to this script.
DB_PATH = os.environ.get(
    "ANNOUNCER_DB", os.path.join(os.path.dirname(__file__), "announcer.db")
)

# Seconds between poll sweeps across all watched subreddits.
POLL_INTERVAL = float(os.environ.get("ANNOUNCER_POLL_INTERVAL", "60"))
# How many posts to pull from each subreddit's /new per sweep. 25 is plenty
# unless a sub gets >25 posts within one POLL_INTERVAL.
FETCH_LIMIT = int(os.environ.get("ANNOUNCER_FETCH_LIMIT", "25"))

logging.basicConfig(
    # logging only recognises upper-case level names; normalise so that
    # ANNOUNCER_LOGLEVEL=debug works instead of raising ValueError at import.
    level=os.environ.get("ANNOUNCER_LOGLEVEL", "INFO").upper(),
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger("announcer")
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# State storage
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Store:
    """SQLite-backed state: which subreddits go to which channels, and which
    submissions have already been announced.

    Two tables are created on first use:

    - ``watches``: one row per (subreddit, guild) with the target channel,
      who added it, when, and a ``bootstrapped`` flag (0 until the first
      silent baseline sweep has run).
    - ``seen_posts``: one row per (post_id, subreddit) already handled.

    Every mutating method commits before returning.
    """

    def __init__(self, path: str):
        """Open (or create) the database at ``path`` and ensure the schema exists."""
        # check_same_thread=False: discord.py runs the loop in one thread but
        # the connection may be touched from task callbacks on the same loop.
        self.db = sqlite3.connect(path, check_same_thread=False)
        self.db.execute(
            """
            CREATE TABLE IF NOT EXISTS watches (
                subreddit TEXT NOT NULL,
                guild_id INTEGER NOT NULL,
                channel_id INTEGER NOT NULL,
                added_by INTEGER NOT NULL,
                added_at INTEGER NOT NULL,
                bootstrapped INTEGER NOT NULL DEFAULT 0,
                PRIMARY KEY (subreddit, guild_id)
            )
            """
        )
        self.db.execute(
            """
            CREATE TABLE IF NOT EXISTS seen_posts (
                post_id TEXT NOT NULL,
                subreddit TEXT NOT NULL,
                seen_at INTEGER NOT NULL,
                PRIMARY KEY (post_id, subreddit)
            )
            """
        )
        self.db.commit()

    # --- watches ----------------------------------------------------------

    def add_watch(self, subreddit: str, guild_id: int, channel_id: int, added_by: int) -> None:
        """Insert or re-point a watch. Re-watching resets the bootstrap flag so
        the new channel gets a clean silent baseline rather than a backlog."""
        self.db.execute(
            """
            INSERT INTO watches (subreddit, guild_id, channel_id, added_by, added_at, bootstrapped)
            VALUES (?, ?, ?, ?, ?, 0)
            ON CONFLICT(subreddit, guild_id) DO UPDATE SET
                channel_id = excluded.channel_id,
                added_by = excluded.added_by,
                added_at = excluded.added_at,
                bootstrapped = 0
            """,
            (subreddit, guild_id, channel_id, added_by, int(time.time())),
        )
        self.db.commit()

    def remove_watch(self, subreddit: str, guild_id: int) -> bool:
        """Delete the watch for (subreddit, guild). Returns True if a row was removed."""
        cur = self.db.execute(
            "DELETE FROM watches WHERE subreddit = ? AND guild_id = ?",
            (subreddit, guild_id),
        )
        self.db.commit()
        return cur.rowcount > 0

    def list_watches(self, guild_id: int) -> list:
        """Return ``(subreddit, channel_id)`` rows for one guild, sorted by subreddit."""
        return self.db.execute(
            "SELECT subreddit, channel_id FROM watches WHERE guild_id = ? ORDER BY subreddit",
            (guild_id,),
        ).fetchall()

    def all_watches(self) -> list:
        """Every watch across all guilds, for the poll loop.

        Rows are ``(subreddit, guild_id, channel_id, bootstrapped)``.
        """
        return self.db.execute(
            "SELECT subreddit, guild_id, channel_id, bootstrapped FROM watches"
        ).fetchall()

    def distinct_subreddits(self) -> list:
        """Return each watched subreddit name once, regardless of guild count."""
        return [
            r[0]
            for r in self.db.execute(
                "SELECT DISTINCT subreddit FROM watches"
            ).fetchall()
        ]

    def mark_bootstrapped(self, subreddit: str) -> None:
        """Set the bootstrap flag on every watch row (all guilds) of ``subreddit``."""
        self.db.execute(
            "UPDATE watches SET bootstrapped = 1 WHERE subreddit = ?", (subreddit,)
        )
        self.db.commit()

    # --- seen posts -------------------------------------------------------

    def is_seen(self, post_id: str, subreddit: str) -> bool:
        """Return True if ``post_id`` has already been recorded for ``subreddit``."""
        return (
            self.db.execute(
                "SELECT 1 FROM seen_posts WHERE post_id = ? AND subreddit = ?",
                (post_id, subreddit),
            ).fetchone()
            is not None
        )

    def mark_seen(self, post_id: str, subreddit: str) -> None:
        """Record ``post_id`` as handled; a repeat call for the same pair is a no-op."""
        self.db.execute(
            "INSERT OR IGNORE INTO seen_posts (post_id, subreddit, seen_at) VALUES (?, ?, ?)",
            (post_id, subreddit, int(time.time())),
        )
        self.db.commit()

    def prune_seen(self, subreddit: str, keep: int = 500) -> None:
        """Keep the seen table from growing unbounded: retain the most recent
        `keep` ids per subreddit. Old posts won't reappear on /new anyway."""
        # SQLite treats a negative LIMIT as "no limit", which would silently
        # turn pruning into a no-op; clamp so keep <= 0 means "keep nothing".
        keep = max(keep, 0)
        # seen_at only has one-second resolution, so a whole batch of posts can
        # share a timestamp. Break ties on rowid (insertion order) so the rows
        # that survive the cut-off are deterministically the newest inserted.
        self.db.execute(
            """
            DELETE FROM seen_posts
            WHERE subreddit = ? AND post_id NOT IN (
                SELECT post_id FROM seen_posts
                WHERE subreddit = ?
                ORDER BY seen_at DESC, rowid DESC
                LIMIT ?
            )
            """,
            (subreddit, subreddit, keep),
        )
        self.db.commit()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Discord client
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AnnouncerBot(discord.Client):
    """Discord client that polls watched subreddits and announces new posts.

    The watch list and the set of already-handled submissions live in the
    ``Store`` passed to the constructor. A background ``tasks.loop`` sweeps
    every watched subreddit once per poll interval.
    """

    def __init__(self, store: Store):
        # No privileged intents needed: we only post, never read messages.
        intents = discord.Intents.default()
        super().__init__(intents=intents)
        self.store = store
        self.tree = app_commands.CommandTree(self)
        # Created in setup_hook, once an event loop is running.
        self.reddit: asyncpraw.Reddit | None = None
        self.poll_interval = POLL_INTERVAL

    async def setup_hook(self):
        """Create the read-only Reddit client, sync slash commands, start polling."""
        self.reddit = asyncpraw.Reddit(
            client_id=REDDIT_CLIENT_ID,
            client_secret=REDDIT_CLIENT_SECRET,
            user_agent=REDDIT_USER_AGENT,
        )
        self.reddit.read_only = True
        await self.tree.sync()
        self.poll_loop.change_interval(seconds=self.poll_interval)
        self.poll_loop.start()

    async def on_ready(self):
        """Log who we are and how much we are watching."""
        log.info("Logged in as %s (id %s)", self.user, self.user.id)
        log.info(
            "Watching %d subreddit(s) across %d guild(s)",
            len(self.store.distinct_subreddits()),
            len(self.guilds),
        )

    async def close(self):
        """Stop polling, close the Reddit session, then close the Discord client."""
        # Cancel the sweep first so it cannot run against a closed Reddit
        # session while the Discord side is still shutting down.
        self.poll_loop.cancel()
        if self.reddit:
            await self.reddit.close()
        await super().close()

    # --- the poll loop ----------------------------------------------------

    @tasks.loop(seconds=POLL_INTERVAL)
    async def poll_loop(self):
        """One sweep: fetch each watched subreddit once and fan out to its targets."""
        # Group watches by subreddit so each is fetched once per sweep even if
        # several channels (or guilds) want it.
        by_sub: dict[str, list[tuple[int, int, int]]] = {}
        for subreddit, guild_id, channel_id, bootstrapped in self.store.all_watches():
            by_sub.setdefault(subreddit, []).append(
                (guild_id, channel_id, bootstrapped)
            )

        for subreddit, targets in by_sub.items():
            try:
                await self._poll_subreddit(subreddit, targets)
            except Exception:
                # One bad subreddit (deleted, private, banned) must not kill
                # the whole sweep — log and move on.
                log.exception("Error polling r/%s", subreddit)

    async def _poll_subreddit(self, subreddit: str, targets: list[tuple[int, int, int]]):
        """Fetch /new for one subreddit and announce unseen posts.

        ``targets`` holds ``(guild_id, channel_id, bootstrapped)`` for every
        watch of this subreddit. Targets that are not bootstrapped yet receive
        nothing this sweep (silent baseline); targets that already are
        bootstrapped keep receiving announcements, so adding the subreddit in
        one guild never swallows posts for another guild.
        """
        sub = await self.reddit.subreddit(subreddit)
        # Newest first from Reddit; reverse so we announce oldest->newest.
        posts = [s async for s in sub.new(limit=FETCH_LIMIT)]
        posts.reverse()

        ready = [(guild_id, channel_id) for guild_id, channel_id, booted in targets if booted]
        has_pending = len(ready) < len(targets)

        for post in posts:
            if self.store.is_seen(post.id, subreddit):
                continue
            # Mark before sending: a failed send is logged, not retried, so a
            # broken channel cannot cause the same post to be re-sent forever.
            self.store.mark_seen(post.id, subreddit)
            if not ready:
                # Pure bootstrap sweep: nothing to announce to anyone.
                continue
            embed = self._build_embed(subreddit, post)
            for guild_id, channel_id in ready:
                await self._send_embed(subreddit, post.id, guild_id, channel_id, embed)

        if has_pending:
            # Everything currently on /new is now in seen_posts, which is the
            # baseline for the freshly added targets. Flip the flag for every
            # target row of this subreddit.
            self.store.mark_bootstrapped(subreddit)
            log.info(
                "Bootstrapped r/%s (%d existing posts marked seen)",
                subreddit,
                len(posts),
            )

        self.store.prune_seen(subreddit)

    async def _send_embed(self, subreddit: str, post_id: str, guild_id: int, channel_id: int, embed):
        """Send one embed to one channel; log (never raise) on a missing channel or send error."""
        channel = self.get_channel(channel_id)
        if channel is None:
            log.warning(
                "Channel %s (guild %s) for r/%s not found; skipping",
                channel_id,
                guild_id,
                subreddit,
            )
            return
        try:
            await channel.send(embed=embed)
        except discord.DiscordException:
            log.exception(
                "Failed to post r/%s/%s to channel %s",
                subreddit,
                post_id,
                channel_id,
            )

    @poll_loop.before_loop
    async def _before_poll(self):
        """Hold the first sweep until the client is ready."""
        await self.wait_until_ready()

    @staticmethod
    def _build_embed(subreddit: str, post) -> discord.Embed:
        """Build the announcement embed for a submission.

        Reads ``title`` and ``permalink`` from ``post`` directly; ``author``,
        ``is_self``, ``url`` and ``created_utc`` are read defensively with
        ``getattr``.
        """
        title = post.title
        if len(title) > 256:  # Discord embed title hard limit
            title = title[:253] + "…"
        permalink = f"https://reddit.com{post.permalink}"
        embed = discord.Embed(
            title=title,
            url=permalink,
            color=0xFF4500,  # reddit orange
        )
        embed.set_author(name=f"r/{subreddit}", url=f"https://reddit.com/r/{subreddit}")
        author = getattr(post, "author", None)
        embed.add_field(
            name="Author", value=f"u/{author}" if author else "[deleted]", inline=True
        )
        # Self-text preview for text posts; thumbnail/link for the rest.
        if getattr(post, "is_self", False) and post.selftext:
            body = post.selftext
            if len(body) > 500:
                body = body[:497] + "…"
            embed.description = body
        elif getattr(post, "url", None):
            url = post.url
            if url.lower().endswith((".jpg", ".jpeg", ".png", ".gif", ".webp")):
                embed.set_image(url=url)
            elif len(url) <= 1024:
                embed.add_field(name="Link", value=url, inline=False)
            # else: an embed field value may not exceed 1024 characters and a
            # truncated URL would be useless; the title still links to the post.
        embed.set_footer(text="Reddit")
        created = getattr(post, "created_utc", None)
        if created:
            embed.timestamp = datetime.datetime.fromtimestamp(
                created, tz=datetime.timezone.utc
            )
        return embed
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Slash commands
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _clean_subreddit(name: str) -> str:
|
||||
"""Normalise user input like 'r/Python', '/r/python ', 'Python' -> 'python'."""
|
||||
name = name.strip().lstrip("/")
|
||||
if name.lower().startswith("r/"):
|
||||
name = name[2:]
|
||||
return name.strip().lower()
|
||||
|
||||
|
||||
def register_commands(bot: AnnouncerBot):
    """Attach the /watch, /unwatch, /watching and /setinterval slash commands
    to ``bot.tree``.

    All commands are guild-only. The commands that change state default to
    requiring the Manage Server permission, so an ordinary member cannot
    redirect announcements or change how often Reddit is polled.
    """
    tree = bot.tree

    @tree.command(name="watch", description="Announce a subreddit's new posts in a channel.")
    @app_commands.guild_only()
    @app_commands.default_permissions(manage_guild=True)
    @app_commands.describe(
        subreddit="Subreddit name, e.g. python or r/python",
        channel="Channel to post in (defaults to the current channel)",
    )
    async def watch(
        interaction: discord.Interaction,
        subreddit: str,
        channel: discord.TextChannel | None = None,
    ):
        # Kept as a second line of defence behind guild_only().
        if interaction.guild_id is None:
            await interaction.response.send_message(
                "Use this in a server, not a DM.", ephemeral=True
            )
            return
        name = _clean_subreddit(subreddit)
        if not name:
            await interaction.response.send_message(
                "That doesn't look like a subreddit name.", ephemeral=True
            )
            return

        target = channel or interaction.channel
        # Verify the subreddit exists and is readable before committing.
        # Defer first: the Reddit round-trip may take longer than the window
        # Discord allows for an initial response.
        await interaction.response.defer(ephemeral=True)
        try:
            sub = await bot.reddit.subreddit(name)
            await sub.load()  # raises if it doesn't exist / is private / banned
        except Exception as exc:
            log.info("Rejected r/%s: %s", name, exc)
            await interaction.followup.send(
                f"Couldn't access **r/{name}** — check the name (and that it's public).",
                ephemeral=True,
            )
            return

        bot.store.add_watch(name, interaction.guild_id, target.id, interaction.user.id)
        await interaction.followup.send(
            f"Now watching **r/{name}** → {target.mention}. "
            f"Existing posts are skipped; you'll get new ones from here on.",
            ephemeral=True,
        )

    @tree.command(name="unwatch", description="Stop announcing a subreddit in this server.")
    @app_commands.guild_only()
    @app_commands.default_permissions(manage_guild=True)
    @app_commands.describe(subreddit="Subreddit to stop watching")
    async def unwatch(interaction: discord.Interaction, subreddit: str):
        if interaction.guild_id is None:
            await interaction.response.send_message(
                "Use this in a server, not a DM.", ephemeral=True
            )
            return
        name = _clean_subreddit(subreddit)
        removed = bot.store.remove_watch(name, interaction.guild_id)
        if removed:
            await interaction.response.send_message(
                f"Stopped watching **r/{name}**.", ephemeral=True
            )
        else:
            await interaction.response.send_message(
                f"Wasn't watching **r/{name}** here.", ephemeral=True
            )

    @tree.command(name="watching", description="List the subreddits announced in this server.")
    @app_commands.guild_only()
    async def watching(interaction: discord.Interaction):
        if interaction.guild_id is None:
            await interaction.response.send_message(
                "Use this in a server, not a DM.", ephemeral=True
            )
            return
        rows = bot.store.list_watches(interaction.guild_id)
        if not rows:
            await interaction.response.send_message(
                "Not watching any subreddits here yet. Add one with `/watch`.",
                ephemeral=True,
            )
            return
        lines = [f"• **r/{sub}** → <#{chan}>" for sub, chan in rows]
        await interaction.response.send_message("\n".join(lines), ephemeral=True)

    # NOTE(review): the poll interval is global to the bot process, so a
    # manager of any one server changes it for every server. Manage Server is
    # the minimum bar applied here; consider restricting this to the bot owner.
    @tree.command(name="setinterval", description="Set how often (seconds) new posts are checked.")
    @app_commands.guild_only()
    @app_commands.default_permissions(manage_guild=True)
    @app_commands.describe(seconds="Poll interval in seconds (minimum 15)")
    async def setinterval(interaction: discord.Interaction, seconds: int):
        if seconds < 15:
            await interaction.response.send_message(
                "Minimum interval is 15 seconds (Reddit rate limits).", ephemeral=True
            )
            return
        bot.poll_interval = float(seconds)
        bot.poll_loop.change_interval(seconds=float(seconds))
        await interaction.response.send_message(
            f"Poll interval set to {seconds}s.", ephemeral=True
        )
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
    """Check required configuration, open the state DB and run the bot.

    Exits via SystemExit with an explanatory message when the Discord token
    or either Reddit credential is missing.
    """
    if not DISCORD_TOKEN:
        raise SystemExit("DISCORD_TOKEN is not set — see config.example.env")

    reddit_configured = bool(REDDIT_CLIENT_ID) and bool(REDDIT_CLIENT_SECRET)
    if not reddit_configured:
        raise SystemExit(
            "REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET not set — see config.example.env"
        )

    state = Store(DB_PATH)
    log.info("State DB: %s", DB_PATH)

    client = AnnouncerBot(state)
    register_commands(client)
    # log_handler=None: logging is already configured at module import.
    client.run(DISCORD_TOKEN, log_handler=None)


# Start the bot only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue