feat: add UserMemory table and CRUD methods for conversational memory

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-26 12:48:54 -05:00
parent 333fbb3932
commit 75adafefd6

View File

@@ -164,6 +164,22 @@ class Database:
)
""")
cursor.execute("""
IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'UserMemory')
CREATE TABLE UserMemory (
Id BIGINT IDENTITY(1,1) PRIMARY KEY,
UserId BIGINT NOT NULL,
Memory NVARCHAR(500) NOT NULL,
Topics NVARCHAR(200) NOT NULL,
Importance NVARCHAR(10) NOT NULL,
ExpiresAt DATETIME2 NOT NULL,
Source NVARCHAR(20) NOT NULL,
CreatedAt DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME(),
INDEX IX_UserMemory_UserId (UserId),
INDEX IX_UserMemory_ExpiresAt (ExpiresAt)
)
""")
cursor.close()
def _parse_database_name(self) -> str:
@@ -491,6 +507,196 @@ class Database:
finally:
conn.close()
# ------------------------------------------------------------------
# UserMemory (conversational memory per user)
# ------------------------------------------------------------------
async def save_memory(
self,
user_id: int,
memory: str,
topics: str,
importance: str,
expires_at: datetime,
source: str,
) -> None:
"""Insert a single memory row for a user."""
if not self._available:
return
try:
await asyncio.to_thread(
self._save_memory_sync,
user_id, memory, topics, importance, expires_at, source,
)
except Exception:
logger.exception("Failed to save memory")
def _save_memory_sync(self, user_id, memory, topics, importance, expires_at, source):
conn = self._connect()
try:
cursor = conn.cursor()
cursor.execute(
"""INSERT INTO UserMemory (UserId, Memory, Topics, Importance, ExpiresAt, Source)
VALUES (?, ?, ?, ?, ?, ?)""",
user_id,
memory[:500],
topics[:200],
importance[:10],
expires_at,
source[:20],
)
cursor.close()
finally:
conn.close()
async def get_recent_memories(self, user_id: int, limit: int = 10) -> list[dict]:
"""Get the N most recent non-expired memories for a user."""
if not self._available:
return []
try:
return await asyncio.to_thread(self._get_recent_memories_sync, user_id, limit)
except Exception:
logger.exception("Failed to get recent memories")
return []
def _get_recent_memories_sync(self, user_id, limit) -> list[dict]:
conn = self._connect()
try:
cursor = conn.cursor()
cursor.execute(
"""SELECT TOP (?) Memory, Topics, Importance, CreatedAt
FROM UserMemory
WHERE UserId = ? AND ExpiresAt > SYSUTCDATETIME()
ORDER BY CreatedAt DESC""",
limit, user_id,
)
rows = cursor.fetchall()
cursor.close()
return [
{
"memory": row[0],
"topics": row[1],
"importance": row[2],
"created_at": row[3],
}
for row in rows
]
finally:
conn.close()
async def get_memories_by_topics(self, user_id: int, topic_keywords: list[str], limit: int = 10) -> list[dict]:
"""Get non-expired memories matching any of the given topic keywords via LIKE."""
if not self._available:
return []
try:
return await asyncio.to_thread(
self._get_memories_by_topics_sync, user_id, topic_keywords, limit,
)
except Exception:
logger.exception("Failed to get memories by topics")
return []
def _get_memories_by_topics_sync(self, user_id, topic_keywords, limit) -> list[dict]:
conn = self._connect()
try:
cursor = conn.cursor()
if not topic_keywords:
cursor.close()
return []
# Build OR conditions for each keyword
conditions = " OR ".join(["Topics LIKE ?" for _ in topic_keywords])
params = [limit, user_id] + [f"%{kw}%" for kw in topic_keywords]
cursor.execute(
f"""SELECT TOP (?) Memory, Topics, Importance, CreatedAt
FROM UserMemory
WHERE UserId = ? AND ExpiresAt > SYSUTCDATETIME()
AND ({conditions})
ORDER BY
CASE Importance
WHEN 'high' THEN 1
WHEN 'medium' THEN 2
WHEN 'low' THEN 3
ELSE 4
END,
CreatedAt DESC""",
*params,
)
rows = cursor.fetchall()
cursor.close()
return [
{
"memory": row[0],
"topics": row[1],
"importance": row[2],
"created_at": row[3],
}
for row in rows
]
finally:
conn.close()
async def prune_expired_memories(self) -> int:
"""Delete all expired memories. Returns count deleted."""
if not self._available:
return 0
try:
return await asyncio.to_thread(self._prune_expired_memories_sync)
except Exception:
logger.exception("Failed to prune expired memories")
return 0
def _prune_expired_memories_sync(self) -> int:
conn = self._connect()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM UserMemory WHERE ExpiresAt < SYSUTCDATETIME()")
count = cursor.rowcount
cursor.close()
return count
finally:
conn.close()
async def prune_excess_memories(self, user_id: int, cap: int = 50) -> int:
"""Delete excess memories for a user beyond the cap, keeping high importance and newest first.
Returns count deleted."""
if not self._available:
return 0
try:
return await asyncio.to_thread(self._prune_excess_memories_sync, user_id, cap)
except Exception:
logger.exception("Failed to prune excess memories")
return 0
def _prune_excess_memories_sync(self, user_id, cap) -> int:
conn = self._connect()
try:
cursor = conn.cursor()
cursor.execute(
"""DELETE FROM UserMemory
WHERE Id IN (
SELECT Id FROM (
SELECT Id, ROW_NUMBER() OVER (
ORDER BY
CASE Importance
WHEN 'high' THEN 1
WHEN 'medium' THEN 2
WHEN 'low' THEN 3
ELSE 4
END,
CreatedAt DESC
) AS rn
FROM UserMemory
WHERE UserId = ?
) ranked
WHERE rn > ?
)""",
user_id, cap,
)
count = cursor.rowcount
cursor.close()
return count
finally:
conn.close()
async def close(self):
"""No persistent connection to close (connections are per-operation)."""
pass