# Changelog (recent fixes):
# - Fix dirty-user flush race: discard IDs individually after successful save
# - Escape LIKE wildcards in LLM-generated topic keywords for DB queries
# - Anonymize absent-member aliases to prevent LLM de-anonymization
# - Pass correct MIME type to vision model based on image file extension
# - Use enumerate instead of list.index() in bcs-scan loop
# - Allow bot @mentions with non-report intent to fall through to moderation
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from datetime import datetime, timezone
|
|
|
|
logger = logging.getLogger("bcs.database")
|
|
|
|
|
|
class Database:
    """SQL Server persistence layer for the bot (pyodbc, one connection per op).

    When the connection string is missing, pyodbc is unavailable, or schema
    setup fails, the instance degrades to memory-only mode and every method
    becomes a no-op.
    """

    def __init__(self):
        # Connection string comes from the environment and may be empty;
        # init() decides whether the database is usable.
        self._conn_str = os.getenv("DB_CONNECTION_STRING", "")
        # Flipped to True only after a successful init(); guards every operation.
        self._available = False
|
async def init(self) -> bool:
|
|
"""Initialize the database connection and create schema.
|
|
Returns True if DB is available, False for memory-only mode."""
|
|
if not self._conn_str:
|
|
logger.warning("DB_CONNECTION_STRING not set — running in memory-only mode.")
|
|
return False
|
|
|
|
try:
|
|
import pyodbc
|
|
self._pyodbc = pyodbc
|
|
except ImportError:
|
|
logger.warning("pyodbc not installed — running in memory-only mode.")
|
|
return False
|
|
|
|
try:
|
|
conn = await asyncio.to_thread(self._connect)
|
|
await asyncio.to_thread(self._create_schema, conn)
|
|
conn.close()
|
|
self._available = True
|
|
logger.info("Database initialized successfully.")
|
|
return True
|
|
except Exception:
|
|
logger.exception("Database initialization failed — running in memory-only mode.")
|
|
return False
|
|
|
|
def _connect(self):
|
|
return self._pyodbc.connect(self._conn_str, autocommit=True)
|
|
|
|
def _create_schema(self, conn):
|
|
cursor = conn.cursor()
|
|
|
|
# Create database if it doesn't exist
|
|
db_name = self._parse_database_name()
|
|
if db_name:
|
|
cursor.execute(
|
|
f"IF DB_ID('{db_name}') IS NULL CREATE DATABASE [{db_name}]"
|
|
)
|
|
cursor.execute(f"USE [{db_name}]")
|
|
|
|
cursor.execute("""
|
|
IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'Messages')
|
|
CREATE TABLE Messages (
|
|
Id BIGINT IDENTITY(1,1) PRIMARY KEY,
|
|
GuildId BIGINT NOT NULL,
|
|
ChannelId BIGINT NOT NULL,
|
|
UserId BIGINT NOT NULL,
|
|
Username NVARCHAR(100) NOT NULL,
|
|
Content NVARCHAR(MAX) NOT NULL,
|
|
MessageTs DATETIME2 NOT NULL,
|
|
CreatedAt DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME()
|
|
)
|
|
""")
|
|
|
|
cursor.execute("""
|
|
IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'AnalysisResults')
|
|
CREATE TABLE AnalysisResults (
|
|
Id BIGINT IDENTITY(1,1) PRIMARY KEY,
|
|
MessageId BIGINT NOT NULL REFERENCES Messages(Id),
|
|
ToxicityScore FLOAT NOT NULL,
|
|
DramaScore FLOAT NOT NULL,
|
|
Categories NVARCHAR(500) NOT NULL,
|
|
Reasoning NVARCHAR(MAX) NOT NULL,
|
|
OffTopic BIT NOT NULL DEFAULT 0,
|
|
TopicCategory NVARCHAR(100) NULL,
|
|
TopicReasoning NVARCHAR(MAX) NULL,
|
|
CreatedAt DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME()
|
|
)
|
|
""")
|
|
|
|
cursor.execute("""
|
|
IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'Actions')
|
|
CREATE TABLE Actions (
|
|
Id BIGINT IDENTITY(1,1) PRIMARY KEY,
|
|
GuildId BIGINT NOT NULL,
|
|
UserId BIGINT NOT NULL,
|
|
Username NVARCHAR(100) NOT NULL,
|
|
ActionType NVARCHAR(50) NOT NULL,
|
|
MessageId BIGINT NULL REFERENCES Messages(Id),
|
|
Details NVARCHAR(MAX) NULL,
|
|
CreatedAt DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME()
|
|
)
|
|
""")
|
|
|
|
cursor.execute("""
|
|
IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'UserState')
|
|
CREATE TABLE UserState (
|
|
UserId BIGINT NOT NULL PRIMARY KEY,
|
|
OffenseCount INT NOT NULL DEFAULT 0,
|
|
Immune BIT NOT NULL DEFAULT 0,
|
|
OffTopicCount INT NOT NULL DEFAULT 0,
|
|
UpdatedAt DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME()
|
|
)
|
|
""")
|
|
|
|
# --- Schema migrations for coherence feature ---
|
|
cursor.execute("""
|
|
IF COL_LENGTH('AnalysisResults', 'CoherenceScore') IS NULL
|
|
ALTER TABLE AnalysisResults ADD CoherenceScore FLOAT NULL
|
|
""")
|
|
cursor.execute("""
|
|
IF COL_LENGTH('AnalysisResults', 'CoherenceFlag') IS NULL
|
|
ALTER TABLE AnalysisResults ADD CoherenceFlag NVARCHAR(50) NULL
|
|
""")
|
|
cursor.execute("""
|
|
IF COL_LENGTH('UserState', 'BaselineCoherence') IS NULL
|
|
ALTER TABLE UserState ADD BaselineCoherence FLOAT NOT NULL DEFAULT 0.85
|
|
""")
|
|
|
|
# --- Schema migration for per-user LLM notes ---
|
|
cursor.execute("""
|
|
IF COL_LENGTH('UserState', 'UserNotes') IS NULL
|
|
ALTER TABLE UserState ADD UserNotes NVARCHAR(MAX) NULL
|
|
""")
|
|
|
|
# --- Schema migration for warned flag (require warning before mute) ---
|
|
cursor.execute("""
|
|
IF COL_LENGTH('UserState', 'Warned') IS NULL
|
|
ALTER TABLE UserState ADD Warned BIT NOT NULL DEFAULT 0
|
|
""")
|
|
|
|
# --- Schema migration for persisting last offense time ---
|
|
cursor.execute("""
|
|
IF COL_LENGTH('UserState', 'LastOffenseAt') IS NULL
|
|
ALTER TABLE UserState ADD LastOffenseAt FLOAT NULL
|
|
""")
|
|
|
|
# --- Schema migration for user aliases/nicknames ---
|
|
cursor.execute("""
|
|
IF COL_LENGTH('UserState', 'Aliases') IS NULL
|
|
ALTER TABLE UserState ADD Aliases NVARCHAR(500) NULL
|
|
""")
|
|
|
|
# --- Schema migration for warning expiration ---
|
|
cursor.execute("""
|
|
IF COL_LENGTH('UserState', 'WarningExpiresAt') IS NULL
|
|
ALTER TABLE UserState ADD WarningExpiresAt FLOAT NULL
|
|
""")
|
|
|
|
cursor.execute("""
|
|
IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'BotSettings')
|
|
CREATE TABLE BotSettings (
|
|
SettingKey NVARCHAR(100) NOT NULL PRIMARY KEY,
|
|
SettingValue NVARCHAR(MAX) NULL,
|
|
UpdatedAt DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME()
|
|
)
|
|
""")
|
|
|
|
cursor.execute("""
|
|
IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'LlmLog')
|
|
CREATE TABLE LlmLog (
|
|
Id BIGINT IDENTITY(1,1) PRIMARY KEY,
|
|
RequestType NVARCHAR(50) NOT NULL,
|
|
Model NVARCHAR(100) NOT NULL,
|
|
InputTokens INT NULL,
|
|
OutputTokens INT NULL,
|
|
DurationMs INT NOT NULL,
|
|
Success BIT NOT NULL,
|
|
Request NVARCHAR(MAX) NOT NULL,
|
|
Response NVARCHAR(MAX) NULL,
|
|
Error NVARCHAR(MAX) NULL,
|
|
CreatedAt DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME()
|
|
)
|
|
""")
|
|
|
|
cursor.execute("""
|
|
IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'UserMemory')
|
|
CREATE TABLE UserMemory (
|
|
Id BIGINT IDENTITY(1,1) PRIMARY KEY,
|
|
UserId BIGINT NOT NULL,
|
|
Memory NVARCHAR(500) NOT NULL,
|
|
Topics NVARCHAR(200) NOT NULL,
|
|
Importance NVARCHAR(10) NOT NULL,
|
|
ExpiresAt DATETIME2 NOT NULL,
|
|
Source NVARCHAR(20) NOT NULL,
|
|
CreatedAt DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME(),
|
|
INDEX IX_UserMemory_UserId (UserId),
|
|
INDEX IX_UserMemory_ExpiresAt (ExpiresAt)
|
|
)
|
|
""")
|
|
|
|
cursor.close()
|
|
|
|
def _parse_database_name(self) -> str:
|
|
"""Extract DATABASE= value from the connection string."""
|
|
for part in self._conn_str.split(";"):
|
|
if part.strip().upper().startswith("DATABASE="):
|
|
return part.split("=", 1)[1].strip()
|
|
return ""
|
|
|
|
# ------------------------------------------------------------------
|
|
# Message + Analysis (awaited — we need the returned message ID)
|
|
# ------------------------------------------------------------------
|
|
async def save_message_and_analysis(
|
|
self,
|
|
guild_id: int,
|
|
channel_id: int,
|
|
user_id: int,
|
|
username: str,
|
|
content: str,
|
|
message_ts: datetime,
|
|
toxicity_score: float,
|
|
drama_score: float,
|
|
categories: list[str],
|
|
reasoning: str,
|
|
off_topic: bool = False,
|
|
topic_category: str | None = None,
|
|
topic_reasoning: str | None = None,
|
|
coherence_score: float | None = None,
|
|
coherence_flag: str | None = None,
|
|
) -> int | None:
|
|
"""Save a message and its analysis result. Returns the message row ID."""
|
|
if not self._available:
|
|
return None
|
|
try:
|
|
return await asyncio.to_thread(
|
|
self._save_message_and_analysis_sync,
|
|
guild_id, channel_id, user_id, username, content, message_ts,
|
|
toxicity_score, drama_score, categories, reasoning,
|
|
off_topic, topic_category, topic_reasoning,
|
|
coherence_score, coherence_flag,
|
|
)
|
|
except Exception:
|
|
logger.exception("Failed to save message and analysis")
|
|
return None
|
|
|
|
def _save_message_and_analysis_sync(
|
|
self,
|
|
guild_id, channel_id, user_id, username, content, message_ts,
|
|
toxicity_score, drama_score, categories, reasoning,
|
|
off_topic, topic_category, topic_reasoning,
|
|
coherence_score, coherence_flag,
|
|
) -> int:
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
|
|
cursor.execute(
|
|
"""INSERT INTO Messages (GuildId, ChannelId, UserId, Username, Content, MessageTs)
|
|
OUTPUT INSERTED.Id
|
|
VALUES (?, ?, ?, ?, ?, ?)""",
|
|
guild_id, channel_id, user_id, username,
|
|
content[:4000], # Truncate very long messages
|
|
message_ts,
|
|
)
|
|
msg_id = cursor.fetchone()[0]
|
|
|
|
cursor.execute(
|
|
"""INSERT INTO AnalysisResults
|
|
(MessageId, ToxicityScore, DramaScore, Categories, Reasoning,
|
|
OffTopic, TopicCategory, TopicReasoning,
|
|
CoherenceScore, CoherenceFlag)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
|
msg_id, toxicity_score, drama_score,
|
|
json.dumps(categories), reasoning[:4000],
|
|
1 if off_topic else 0,
|
|
topic_category, topic_reasoning[:4000] if topic_reasoning else None,
|
|
coherence_score, coherence_flag,
|
|
)
|
|
|
|
cursor.close()
|
|
return msg_id
|
|
finally:
|
|
conn.close()
|
|
|
|
# ------------------------------------------------------------------
|
|
# Actions (fire-and-forget via asyncio.create_task)
|
|
# ------------------------------------------------------------------
|
|
async def save_action(
|
|
self,
|
|
guild_id: int,
|
|
user_id: int,
|
|
username: str,
|
|
action_type: str,
|
|
message_id: int | None = None,
|
|
details: str | None = None,
|
|
) -> None:
|
|
"""Save a moderation action (warning, mute, topic_remind, etc.)."""
|
|
if not self._available:
|
|
return
|
|
try:
|
|
await asyncio.to_thread(
|
|
self._save_action_sync,
|
|
guild_id, user_id, username, action_type, message_id, details,
|
|
)
|
|
except Exception:
|
|
logger.exception("Failed to save action")
|
|
|
|
def _save_action_sync(self, guild_id, user_id, username, action_type, message_id, details):
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"""INSERT INTO Actions (GuildId, UserId, Username, ActionType, MessageId, Details)
|
|
VALUES (?, ?, ?, ?, ?, ?)""",
|
|
guild_id, user_id, username, action_type, message_id,
|
|
details[:4000] if details else None,
|
|
)
|
|
cursor.close()
|
|
finally:
|
|
conn.close()
|
|
|
|
# ------------------------------------------------------------------
|
|
# UserState (upsert via MERGE)
|
|
# ------------------------------------------------------------------
|
|
async def save_user_state(
|
|
self,
|
|
user_id: int,
|
|
offense_count: int,
|
|
immune: bool,
|
|
off_topic_count: int,
|
|
baseline_coherence: float = 0.85,
|
|
user_notes: str | None = None,
|
|
warned: bool = False,
|
|
last_offense_at: float | None = None,
|
|
aliases: str | None = None,
|
|
warning_expires_at: float | None = None,
|
|
) -> None:
|
|
"""Upsert user state (offense count, immunity, off-topic count, coherence baseline, notes, warned, last offense time, aliases, warning expiration)."""
|
|
if not self._available:
|
|
return
|
|
try:
|
|
await asyncio.to_thread(
|
|
self._save_user_state_sync,
|
|
user_id, offense_count, immune, off_topic_count, baseline_coherence, user_notes, warned, last_offense_at, aliases, warning_expires_at,
|
|
)
|
|
except Exception:
|
|
logger.exception("Failed to save user state")
|
|
|
|
def _save_user_state_sync(self, user_id, offense_count, immune, off_topic_count, baseline_coherence, user_notes, warned, last_offense_at, aliases, warning_expires_at):
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"""MERGE UserState AS target
|
|
USING (SELECT ? AS UserId) AS source
|
|
ON target.UserId = source.UserId
|
|
WHEN MATCHED THEN
|
|
UPDATE SET OffenseCount = ?, Immune = ?, OffTopicCount = ?,
|
|
BaselineCoherence = ?, UserNotes = ?, Warned = ?,
|
|
LastOffenseAt = ?, Aliases = ?, WarningExpiresAt = ?,
|
|
UpdatedAt = SYSUTCDATETIME()
|
|
WHEN NOT MATCHED THEN
|
|
INSERT (UserId, OffenseCount, Immune, OffTopicCount, BaselineCoherence, UserNotes, Warned, LastOffenseAt, Aliases, WarningExpiresAt)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);""",
|
|
user_id,
|
|
offense_count, 1 if immune else 0, off_topic_count, baseline_coherence, user_notes, 1 if warned else 0, last_offense_at, aliases, warning_expires_at,
|
|
user_id, offense_count, 1 if immune else 0, off_topic_count, baseline_coherence, user_notes, 1 if warned else 0, last_offense_at, aliases, warning_expires_at,
|
|
)
|
|
cursor.close()
|
|
finally:
|
|
conn.close()
|
|
|
|
async def delete_user_state(self, user_id: int) -> None:
|
|
"""Remove a user's persisted state (used by /bcs-reset)."""
|
|
if not self._available:
|
|
return
|
|
try:
|
|
await asyncio.to_thread(self._delete_user_state_sync, user_id)
|
|
except Exception:
|
|
logger.exception("Failed to delete user state")
|
|
|
|
def _delete_user_state_sync(self, user_id):
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
cursor.execute("DELETE FROM UserState WHERE UserId = ?", user_id)
|
|
cursor.close()
|
|
finally:
|
|
conn.close()
|
|
|
|
# ------------------------------------------------------------------
|
|
# Hydration (load all user states on startup)
|
|
# ------------------------------------------------------------------
|
|
async def load_all_user_states(self) -> list[dict]:
|
|
"""Load all user states from the database for startup hydration.
|
|
Returns list of dicts with user_id, offense_count, immune, off_topic_count."""
|
|
if not self._available:
|
|
return []
|
|
try:
|
|
return await asyncio.to_thread(self._load_all_user_states_sync)
|
|
except Exception:
|
|
logger.exception("Failed to load user states")
|
|
return []
|
|
|
|
def _load_all_user_states_sync(self) -> list[dict]:
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"SELECT UserId, OffenseCount, Immune, OffTopicCount, BaselineCoherence, UserNotes, Warned, LastOffenseAt, Aliases, WarningExpiresAt FROM UserState"
|
|
)
|
|
rows = cursor.fetchall()
|
|
cursor.close()
|
|
return [
|
|
{
|
|
"user_id": row[0],
|
|
"offense_count": row[1],
|
|
"immune": bool(row[2]),
|
|
"off_topic_count": row[3],
|
|
"baseline_coherence": float(row[4]),
|
|
"user_notes": row[5] or "",
|
|
"warned": bool(row[6]),
|
|
"last_offense_at": float(row[7]) if row[7] is not None else 0.0,
|
|
"aliases": row[8] or "",
|
|
"warning_expires_at": float(row[9]) if row[9] is not None else 0.0,
|
|
}
|
|
for row in rows
|
|
]
|
|
finally:
|
|
conn.close()
|
|
|
|
# ------------------------------------------------------------------
|
|
# LLM Log (fire-and-forget via asyncio.create_task)
|
|
# ------------------------------------------------------------------
|
|
async def save_llm_log(
|
|
self,
|
|
request_type: str,
|
|
model: str,
|
|
duration_ms: int,
|
|
success: bool,
|
|
request: str,
|
|
response: str | None = None,
|
|
error: str | None = None,
|
|
input_tokens: int | None = None,
|
|
output_tokens: int | None = None,
|
|
) -> None:
|
|
"""Save an LLM request/response log entry."""
|
|
if not self._available:
|
|
return
|
|
try:
|
|
await asyncio.to_thread(
|
|
self._save_llm_log_sync,
|
|
request_type, model, duration_ms, success, request,
|
|
response, error, input_tokens, output_tokens,
|
|
)
|
|
except Exception:
|
|
logger.exception("Failed to save LLM log")
|
|
|
|
def _save_llm_log_sync(
|
|
self, request_type, model, duration_ms, success, request,
|
|
response, error, input_tokens, output_tokens,
|
|
):
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"""INSERT INTO LlmLog
|
|
(RequestType, Model, InputTokens, OutputTokens, DurationMs,
|
|
Success, Request, Response, Error)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
|
request_type, model, input_tokens, output_tokens, duration_ms,
|
|
1 if success else 0,
|
|
request[:4000] if request else "",
|
|
response[:4000] if response else None,
|
|
error[:4000] if error else None,
|
|
)
|
|
cursor.close()
|
|
finally:
|
|
conn.close()
|
|
|
|
# ------------------------------------------------------------------
|
|
# Bot Settings (key-value store)
|
|
# ------------------------------------------------------------------
|
|
async def save_setting(self, key: str, value: str) -> None:
|
|
if not self._available:
|
|
return
|
|
try:
|
|
await asyncio.to_thread(self._save_setting_sync, key, value)
|
|
except Exception:
|
|
logger.exception("Failed to save setting %s", key)
|
|
|
|
def _save_setting_sync(self, key: str, value: str):
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"""MERGE BotSettings AS target
|
|
USING (SELECT ? AS SettingKey) AS source
|
|
ON target.SettingKey = source.SettingKey
|
|
WHEN MATCHED THEN
|
|
UPDATE SET SettingValue = ?, UpdatedAt = SYSUTCDATETIME()
|
|
WHEN NOT MATCHED THEN
|
|
INSERT (SettingKey, SettingValue) VALUES (?, ?);""",
|
|
key, value, key, value,
|
|
)
|
|
cursor.close()
|
|
finally:
|
|
conn.close()
|
|
|
|
async def load_setting(self, key: str, default: str | None = None) -> str | None:
|
|
if not self._available:
|
|
return default
|
|
try:
|
|
return await asyncio.to_thread(self._load_setting_sync, key, default)
|
|
except Exception:
|
|
logger.exception("Failed to load setting %s", key)
|
|
return default
|
|
|
|
def _load_setting_sync(self, key: str, default: str | None) -> str | None:
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"SELECT SettingValue FROM BotSettings WHERE SettingKey = ?", key
|
|
)
|
|
row = cursor.fetchone()
|
|
cursor.close()
|
|
return row[0] if row else default
|
|
finally:
|
|
conn.close()
|
|
|
|
# ------------------------------------------------------------------
|
|
# UserMemory (conversational memory per user)
|
|
# ------------------------------------------------------------------
|
|
async def save_memory(
|
|
self,
|
|
user_id: int,
|
|
memory: str,
|
|
topics: str,
|
|
importance: str,
|
|
expires_at: datetime,
|
|
source: str,
|
|
) -> None:
|
|
"""Insert a single memory row for a user."""
|
|
if not self._available:
|
|
return
|
|
try:
|
|
await asyncio.to_thread(
|
|
self._save_memory_sync,
|
|
user_id, memory, topics, importance, expires_at, source,
|
|
)
|
|
except Exception:
|
|
logger.exception("Failed to save memory")
|
|
|
|
def _save_memory_sync(self, user_id, memory, topics, importance, expires_at, source):
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
# Skip if an identical memory already exists for this user
|
|
cursor.execute(
|
|
"SELECT COUNT(*) FROM UserMemory WHERE UserId = ? AND Memory = ?",
|
|
user_id, memory[:500],
|
|
)
|
|
if cursor.fetchone()[0] > 0:
|
|
cursor.close()
|
|
return
|
|
cursor.execute(
|
|
"""INSERT INTO UserMemory (UserId, Memory, Topics, Importance, ExpiresAt, Source)
|
|
VALUES (?, ?, ?, ?, ?, ?)""",
|
|
user_id,
|
|
memory[:500],
|
|
topics[:200],
|
|
importance[:10],
|
|
expires_at,
|
|
source[:20],
|
|
)
|
|
cursor.close()
|
|
finally:
|
|
conn.close()
|
|
|
|
async def get_recent_memories(self, user_id: int, limit: int = 5) -> list[dict]:
|
|
"""Get the N most recent non-expired memories for a user."""
|
|
if not self._available:
|
|
return []
|
|
try:
|
|
return await asyncio.to_thread(self._get_recent_memories_sync, user_id, limit)
|
|
except Exception:
|
|
logger.exception("Failed to get recent memories")
|
|
return []
|
|
|
|
def _get_recent_memories_sync(self, user_id, limit) -> list[dict]:
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"""SELECT TOP (?) Memory, Topics, Importance, CreatedAt
|
|
FROM UserMemory
|
|
WHERE UserId = ? AND ExpiresAt > SYSUTCDATETIME()
|
|
ORDER BY CreatedAt DESC""",
|
|
limit, user_id,
|
|
)
|
|
rows = cursor.fetchall()
|
|
cursor.close()
|
|
return [
|
|
{
|
|
"memory": row[0],
|
|
"topics": row[1],
|
|
"importance": row[2],
|
|
"created_at": row[3],
|
|
}
|
|
for row in rows
|
|
]
|
|
finally:
|
|
conn.close()
|
|
|
|
async def get_memories_by_topics(self, user_id: int, topic_keywords: list[str], limit: int = 5) -> list[dict]:
|
|
"""Get non-expired memories matching any of the given topic keywords via LIKE."""
|
|
if not self._available:
|
|
return []
|
|
try:
|
|
return await asyncio.to_thread(
|
|
self._get_memories_by_topics_sync, user_id, topic_keywords, limit,
|
|
)
|
|
except Exception:
|
|
logger.exception("Failed to get memories by topics")
|
|
return []
|
|
|
|
def _get_memories_by_topics_sync(self, user_id, topic_keywords, limit) -> list[dict]:
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
if not topic_keywords:
|
|
cursor.close()
|
|
return []
|
|
# Build OR conditions for each keyword
|
|
conditions = " OR ".join(["Topics LIKE ?" for _ in topic_keywords])
|
|
escaped = [kw.replace("%", "[%]").replace("_", "[_]") for kw in topic_keywords]
|
|
params = [limit, user_id] + [f"%{kw}%" for kw in escaped]
|
|
cursor.execute(
|
|
f"""SELECT TOP (?) Memory, Topics, Importance, CreatedAt
|
|
FROM UserMemory
|
|
WHERE UserId = ? AND ExpiresAt > SYSUTCDATETIME()
|
|
AND ({conditions})
|
|
ORDER BY
|
|
CASE Importance
|
|
WHEN 'high' THEN 1
|
|
WHEN 'medium' THEN 2
|
|
WHEN 'low' THEN 3
|
|
ELSE 4
|
|
END,
|
|
CreatedAt DESC""",
|
|
*params,
|
|
)
|
|
rows = cursor.fetchall()
|
|
cursor.close()
|
|
return [
|
|
{
|
|
"memory": row[0],
|
|
"topics": row[1],
|
|
"importance": row[2],
|
|
"created_at": row[3],
|
|
}
|
|
for row in rows
|
|
]
|
|
finally:
|
|
conn.close()
|
|
|
|
async def prune_expired_memories(self) -> int:
|
|
"""Delete all expired memories. Returns count deleted."""
|
|
if not self._available:
|
|
return 0
|
|
try:
|
|
return await asyncio.to_thread(self._prune_expired_memories_sync)
|
|
except Exception:
|
|
logger.exception("Failed to prune expired memories")
|
|
return 0
|
|
|
|
def _prune_expired_memories_sync(self) -> int:
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
cursor.execute("DELETE FROM UserMemory WHERE ExpiresAt < SYSUTCDATETIME()")
|
|
count = cursor.rowcount
|
|
cursor.close()
|
|
return count
|
|
finally:
|
|
conn.close()
|
|
|
|
async def prune_excess_memories(self, user_id: int, max_memories: int = 50) -> int:
|
|
"""Delete excess memories for a user beyond the cap, keeping high importance and newest first.
|
|
Returns count deleted."""
|
|
if not self._available:
|
|
return 0
|
|
try:
|
|
return await asyncio.to_thread(self._prune_excess_memories_sync, user_id, max_memories)
|
|
except Exception:
|
|
logger.exception("Failed to prune excess memories")
|
|
return 0
|
|
|
|
def _prune_excess_memories_sync(self, user_id, max_memories) -> int:
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"""DELETE FROM UserMemory
|
|
WHERE Id IN (
|
|
SELECT Id FROM (
|
|
SELECT Id, ROW_NUMBER() OVER (
|
|
ORDER BY
|
|
CASE Importance
|
|
WHEN 'high' THEN 1
|
|
WHEN 'medium' THEN 2
|
|
WHEN 'low' THEN 3
|
|
ELSE 4
|
|
END,
|
|
CreatedAt DESC
|
|
) AS rn
|
|
FROM UserMemory
|
|
WHERE UserId = ?
|
|
) ranked
|
|
WHERE rn > ?
|
|
)""",
|
|
user_id, max_memories,
|
|
)
|
|
count = cursor.rowcount
|
|
cursor.close()
|
|
return count
|
|
finally:
|
|
conn.close()
|
|
|
|
# ------------------------------------------------------------------
|
|
# Drama Leaderboard (historical stats from Messages + AnalysisResults + Actions)
|
|
# ------------------------------------------------------------------
|
|
async def get_drama_leaderboard(self, guild_id: int, days: int | None = None) -> list[dict]:
|
|
"""Get per-user drama stats for the leaderboard.
|
|
days=None means all-time. Returns list of dicts sorted by user_id."""
|
|
if not self._available:
|
|
return []
|
|
try:
|
|
return await asyncio.to_thread(self._get_drama_leaderboard_sync, guild_id, days)
|
|
except Exception:
|
|
logger.exception("Failed to get drama leaderboard")
|
|
return []
|
|
|
|
def _get_drama_leaderboard_sync(self, guild_id: int, days: int | None) -> list[dict]:
|
|
conn = self._connect()
|
|
try:
|
|
cursor = conn.cursor()
|
|
|
|
date_filter = ""
|
|
params: list = [guild_id]
|
|
if days is not None:
|
|
date_filter = "AND m.CreatedAt >= DATEADD(DAY, ?, SYSUTCDATETIME())"
|
|
params.append(-days)
|
|
|
|
# Analysis stats from Messages + AnalysisResults
|
|
cursor.execute(f"""
|
|
SELECT
|
|
m.UserId,
|
|
MAX(m.Username) AS Username,
|
|
AVG(ar.ToxicityScore) AS AvgToxicity,
|
|
MAX(ar.ToxicityScore) AS MaxToxicity,
|
|
COUNT(*) AS MessagesAnalyzed
|
|
FROM Messages m
|
|
INNER JOIN AnalysisResults ar ON ar.MessageId = m.Id
|
|
WHERE m.GuildId = ? {date_filter}
|
|
GROUP BY m.UserId
|
|
""", *params)
|
|
|
|
analysis_rows = cursor.fetchall()
|
|
|
|
# Action counts
|
|
action_date_filter = ""
|
|
action_params: list = [guild_id]
|
|
if days is not None:
|
|
action_date_filter = "AND CreatedAt >= DATEADD(DAY, ?, SYSUTCDATETIME())"
|
|
action_params.append(-days)
|
|
|
|
cursor.execute(f"""
|
|
SELECT
|
|
UserId,
|
|
SUM(CASE WHEN ActionType = 'warning' THEN 1 ELSE 0 END) AS Warnings,
|
|
SUM(CASE WHEN ActionType = 'mute' THEN 1 ELSE 0 END) AS Mutes,
|
|
SUM(CASE WHEN ActionType IN ('topic_remind', 'topic_nudge') THEN 1 ELSE 0 END) AS OffTopic
|
|
FROM Actions
|
|
WHERE GuildId = ? {action_date_filter}
|
|
GROUP BY UserId
|
|
""", *action_params)
|
|
|
|
action_map = {}
|
|
for row in cursor.fetchall():
|
|
action_map[row[0]] = {
|
|
"warnings": row[1],
|
|
"mutes": row[2],
|
|
"off_topic": row[3],
|
|
}
|
|
|
|
cursor.close()
|
|
|
|
results = []
|
|
for row in analysis_rows:
|
|
user_id = row[0]
|
|
actions = action_map.get(user_id, {"warnings": 0, "mutes": 0, "off_topic": 0})
|
|
results.append({
|
|
"user_id": user_id,
|
|
"username": row[1],
|
|
"avg_toxicity": float(row[2]),
|
|
"max_toxicity": float(row[3]),
|
|
"messages_analyzed": row[4],
|
|
"warnings": actions["warnings"],
|
|
"mutes": actions["mutes"],
|
|
"off_topic": actions["off_topic"],
|
|
})
|
|
|
|
return results
|
|
finally:
|
|
conn.close()
|
|
|
|
async def close(self):
|
|
"""No persistent connection to close (connections are per-operation)."""
|
|
pass
|