diff --git a/utils/database.py b/utils/database.py index 0bf6214..42b237c 100644 --- a/utils/database.py +++ b/utils/database.py @@ -164,6 +164,22 @@ class Database: ) """) + cursor.execute(""" + IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'UserMemory') + CREATE TABLE UserMemory ( + Id BIGINT IDENTITY(1,1) PRIMARY KEY, + UserId BIGINT NOT NULL, + Memory NVARCHAR(500) NOT NULL, + Topics NVARCHAR(200) NOT NULL, + Importance NVARCHAR(10) NOT NULL, + ExpiresAt DATETIME2 NOT NULL, + Source NVARCHAR(20) NOT NULL, + CreatedAt DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME(), + INDEX IX_UserMemory_UserId (UserId), + INDEX IX_UserMemory_ExpiresAt (ExpiresAt) + ) + """) + cursor.close() def _parse_database_name(self) -> str: @@ -491,6 +507,196 @@ class Database: finally: conn.close() + # ------------------------------------------------------------------ + # UserMemory (conversational memory per user) + # ------------------------------------------------------------------ + async def save_memory( + self, + user_id: int, + memory: str, + topics: str, + importance: str, + expires_at: datetime, + source: str, + ) -> None: + """Insert a single memory row for a user.""" + if not self._available: + return + try: + await asyncio.to_thread( + self._save_memory_sync, + user_id, memory, topics, importance, expires_at, source, + ) + except Exception: + logger.exception("Failed to save memory") + + def _save_memory_sync(self, user_id, memory, topics, importance, expires_at, source): + conn = self._connect() + try: + cursor = conn.cursor() + cursor.execute( + """INSERT INTO UserMemory (UserId, Memory, Topics, Importance, ExpiresAt, Source) + VALUES (?, ?, ?, ?, ?, ?)""", + user_id, + memory[:500], + topics[:200], + importance[:10], + expires_at, + source[:20], + ) + cursor.close() + finally: + conn.close() + + async def get_recent_memories(self, user_id: int, limit: int = 10) -> list[dict]: + """Get the N most recent non-expired memories for a user.""" + if not self._available: + return [] + try: + return await asyncio.to_thread(self._get_recent_memories_sync, user_id, limit) + except Exception: + logger.exception("Failed to get recent memories") + return [] + + def _get_recent_memories_sync(self, user_id, limit) -> list[dict]: + conn = self._connect() + try: + cursor = conn.cursor() + cursor.execute( + """SELECT TOP (?) Memory, Topics, Importance, CreatedAt + FROM UserMemory + WHERE UserId = ? AND ExpiresAt > SYSUTCDATETIME() + ORDER BY CreatedAt DESC""", + limit, user_id, + ) + rows = cursor.fetchall() + cursor.close() + return [ + { + "memory": row[0], + "topics": row[1], + "importance": row[2], + "created_at": row[3], + } + for row in rows + ] + finally: + conn.close() + + async def get_memories_by_topics(self, user_id: int, topic_keywords: list[str], limit: int = 10) -> list[dict]: + """Get non-expired memories matching any of the given topic keywords via LIKE.""" + if not self._available: + return [] + try: + return await asyncio.to_thread( + self._get_memories_by_topics_sync, user_id, topic_keywords, limit, + ) + except Exception: + logger.exception("Failed to get memories by topics") + return [] + + def _get_memories_by_topics_sync(self, user_id, topic_keywords, limit) -> list[dict]: + conn = self._connect() + try: + cursor = conn.cursor() + if not topic_keywords: + cursor.close() + return [] + # Build OR conditions for each keyword + conditions = " OR ".join(["Topics LIKE ?" for _ in topic_keywords]) + params = [limit, user_id] + [f"%{kw}%" for kw in topic_keywords] + cursor.execute( + f"""SELECT TOP (?) Memory, Topics, Importance, CreatedAt + FROM UserMemory + WHERE UserId = ? AND ExpiresAt > SYSUTCDATETIME() + AND ({conditions}) + ORDER BY + CASE Importance + WHEN 'high' THEN 1 + WHEN 'medium' THEN 2 + WHEN 'low' THEN 3 + ELSE 4 + END, + CreatedAt DESC""", + *params, + ) + rows = cursor.fetchall() + cursor.close() + return [ + { + "memory": row[0], + "topics": row[1], + "importance": row[2], + "created_at": row[3], + } + for row in rows + ] + finally: + conn.close() + + async def prune_expired_memories(self) -> int: + """Delete all expired memories. Returns count deleted.""" + if not self._available: + return 0 + try: + return await asyncio.to_thread(self._prune_expired_memories_sync) + except Exception: + logger.exception("Failed to prune expired memories") + return 0 + + def _prune_expired_memories_sync(self) -> int: + conn = self._connect() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM UserMemory WHERE ExpiresAt < SYSUTCDATETIME()") + count = cursor.rowcount + cursor.close() + return count + finally: + conn.close() + + async def prune_excess_memories(self, user_id: int, cap: int = 50) -> int: + """Delete excess memories for a user beyond the cap, keeping high importance and newest first. + Returns count deleted.""" + if not self._available: + return 0 + try: + return await asyncio.to_thread(self._prune_excess_memories_sync, user_id, cap) + except Exception: + logger.exception("Failed to prune excess memories") + return 0 + + def _prune_excess_memories_sync(self, user_id, cap) -> int: + conn = self._connect() + try: + cursor = conn.cursor() + cursor.execute( + """DELETE FROM UserMemory + WHERE Id IN ( + SELECT Id FROM ( + SELECT Id, ROW_NUMBER() OVER ( + ORDER BY + CASE Importance + WHEN 'high' THEN 1 + WHEN 'medium' THEN 2 + WHEN 'low' THEN 3 + ELSE 4 + END, + CreatedAt DESC + ) AS rn + FROM UserMemory + WHERE UserId = ? + ) ranked + WHERE rn > ? + )""", + user_id, cap, + ) + count = cursor.rowcount + cursor.close() + return count + finally: + conn.close() + async def close(self): """No persistent connection to close (connections are per-operation).""" pass