feat: add UserMemory table and CRUD methods for conversational memory
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -164,6 +164,22 @@ class Database:
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'UserMemory')
|
||||
CREATE TABLE UserMemory (
|
||||
Id BIGINT IDENTITY(1,1) PRIMARY KEY,
|
||||
UserId BIGINT NOT NULL,
|
||||
Memory NVARCHAR(500) NOT NULL,
|
||||
Topics NVARCHAR(200) NOT NULL,
|
||||
Importance NVARCHAR(10) NOT NULL,
|
||||
ExpiresAt DATETIME2 NOT NULL,
|
||||
Source NVARCHAR(20) NOT NULL,
|
||||
CreatedAt DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME(),
|
||||
INDEX IX_UserMemory_UserId (UserId),
|
||||
INDEX IX_UserMemory_ExpiresAt (ExpiresAt)
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.close()
|
||||
|
||||
def _parse_database_name(self) -> str:
|
||||
@@ -491,6 +507,196 @@ class Database:
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# UserMemory (conversational memory per user)
|
||||
# ------------------------------------------------------------------
|
||||
async def save_memory(
|
||||
self,
|
||||
user_id: int,
|
||||
memory: str,
|
||||
topics: str,
|
||||
importance: str,
|
||||
expires_at: datetime,
|
||||
source: str,
|
||||
) -> None:
|
||||
"""Insert a single memory row for a user."""
|
||||
if not self._available:
|
||||
return
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self._save_memory_sync,
|
||||
user_id, memory, topics, importance, expires_at, source,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to save memory")
|
||||
|
||||
def _save_memory_sync(self, user_id, memory, topics, importance, expires_at, source):
|
||||
conn = self._connect()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""INSERT INTO UserMemory (UserId, Memory, Topics, Importance, ExpiresAt, Source)
|
||||
VALUES (?, ?, ?, ?, ?, ?)""",
|
||||
user_id,
|
||||
memory[:500],
|
||||
topics[:200],
|
||||
importance[:10],
|
||||
expires_at,
|
||||
source[:20],
|
||||
)
|
||||
cursor.close()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
async def get_recent_memories(self, user_id: int, limit: int = 10) -> list[dict]:
|
||||
"""Get the N most recent non-expired memories for a user."""
|
||||
if not self._available:
|
||||
return []
|
||||
try:
|
||||
return await asyncio.to_thread(self._get_recent_memories_sync, user_id, limit)
|
||||
except Exception:
|
||||
logger.exception("Failed to get recent memories")
|
||||
return []
|
||||
|
||||
def _get_recent_memories_sync(self, user_id, limit) -> list[dict]:
|
||||
conn = self._connect()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""SELECT TOP (?) Memory, Topics, Importance, CreatedAt
|
||||
FROM UserMemory
|
||||
WHERE UserId = ? AND ExpiresAt > SYSUTCDATETIME()
|
||||
ORDER BY CreatedAt DESC""",
|
||||
limit, user_id,
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
cursor.close()
|
||||
return [
|
||||
{
|
||||
"memory": row[0],
|
||||
"topics": row[1],
|
||||
"importance": row[2],
|
||||
"created_at": row[3],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
async def get_memories_by_topics(self, user_id: int, topic_keywords: list[str], limit: int = 10) -> list[dict]:
|
||||
"""Get non-expired memories matching any of the given topic keywords via LIKE."""
|
||||
if not self._available:
|
||||
return []
|
||||
try:
|
||||
return await asyncio.to_thread(
|
||||
self._get_memories_by_topics_sync, user_id, topic_keywords, limit,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to get memories by topics")
|
||||
return []
|
||||
|
||||
def _get_memories_by_topics_sync(self, user_id, topic_keywords, limit) -> list[dict]:
|
||||
conn = self._connect()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
if not topic_keywords:
|
||||
cursor.close()
|
||||
return []
|
||||
# Build OR conditions for each keyword
|
||||
conditions = " OR ".join(["Topics LIKE ?" for _ in topic_keywords])
|
||||
params = [limit, user_id] + [f"%{kw}%" for kw in topic_keywords]
|
||||
cursor.execute(
|
||||
f"""SELECT TOP (?) Memory, Topics, Importance, CreatedAt
|
||||
FROM UserMemory
|
||||
WHERE UserId = ? AND ExpiresAt > SYSUTCDATETIME()
|
||||
AND ({conditions})
|
||||
ORDER BY
|
||||
CASE Importance
|
||||
WHEN 'high' THEN 1
|
||||
WHEN 'medium' THEN 2
|
||||
WHEN 'low' THEN 3
|
||||
ELSE 4
|
||||
END,
|
||||
CreatedAt DESC""",
|
||||
*params,
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
cursor.close()
|
||||
return [
|
||||
{
|
||||
"memory": row[0],
|
||||
"topics": row[1],
|
||||
"importance": row[2],
|
||||
"created_at": row[3],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
async def prune_expired_memories(self) -> int:
|
||||
"""Delete all expired memories. Returns count deleted."""
|
||||
if not self._available:
|
||||
return 0
|
||||
try:
|
||||
return await asyncio.to_thread(self._prune_expired_memories_sync)
|
||||
except Exception:
|
||||
logger.exception("Failed to prune expired memories")
|
||||
return 0
|
||||
|
||||
def _prune_expired_memories_sync(self) -> int:
|
||||
conn = self._connect()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM UserMemory WHERE ExpiresAt < SYSUTCDATETIME()")
|
||||
count = cursor.rowcount
|
||||
cursor.close()
|
||||
return count
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
async def prune_excess_memories(self, user_id: int, cap: int = 50) -> int:
|
||||
"""Delete excess memories for a user beyond the cap, keeping high importance and newest first.
|
||||
Returns count deleted."""
|
||||
if not self._available:
|
||||
return 0
|
||||
try:
|
||||
return await asyncio.to_thread(self._prune_excess_memories_sync, user_id, cap)
|
||||
except Exception:
|
||||
logger.exception("Failed to prune excess memories")
|
||||
return 0
|
||||
|
||||
def _prune_excess_memories_sync(self, user_id, cap) -> int:
|
||||
conn = self._connect()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""DELETE FROM UserMemory
|
||||
WHERE Id IN (
|
||||
SELECT Id FROM (
|
||||
SELECT Id, ROW_NUMBER() OVER (
|
||||
ORDER BY
|
||||
CASE Importance
|
||||
WHEN 'high' THEN 1
|
||||
WHEN 'medium' THEN 2
|
||||
WHEN 'low' THEN 3
|
||||
ELSE 4
|
||||
END,
|
||||
CreatedAt DESC
|
||||
) AS rn
|
||||
FROM UserMemory
|
||||
WHERE UserId = ?
|
||||
) ranked
|
||||
WHERE rn > ?
|
||||
)""",
|
||||
user_id, cap,
|
||||
)
|
||||
count = cursor.rowcount
|
||||
cursor.close()
|
||||
return count
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
async def close(self):
|
||||
"""No persistent connection to close (connections are per-operation)."""
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user