feat: add UserMemory table and CRUD methods for conversational memory
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -164,6 +164,22 @@ class Database:
|
|||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'UserMemory')
|
||||||
|
CREATE TABLE UserMemory (
|
||||||
|
Id BIGINT IDENTITY(1,1) PRIMARY KEY,
|
||||||
|
UserId BIGINT NOT NULL,
|
||||||
|
Memory NVARCHAR(500) NOT NULL,
|
||||||
|
Topics NVARCHAR(200) NOT NULL,
|
||||||
|
Importance NVARCHAR(10) NOT NULL,
|
||||||
|
ExpiresAt DATETIME2 NOT NULL,
|
||||||
|
Source NVARCHAR(20) NOT NULL,
|
||||||
|
CreatedAt DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME(),
|
||||||
|
INDEX IX_UserMemory_UserId (UserId),
|
||||||
|
INDEX IX_UserMemory_ExpiresAt (ExpiresAt)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
def _parse_database_name(self) -> str:
|
def _parse_database_name(self) -> str:
|
||||||
@@ -491,6 +507,196 @@ class Database:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# UserMemory (conversational memory per user)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
async def save_memory(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
memory: str,
|
||||||
|
topics: str,
|
||||||
|
importance: str,
|
||||||
|
expires_at: datetime,
|
||||||
|
source: str,
|
||||||
|
) -> None:
|
||||||
|
"""Insert a single memory row for a user."""
|
||||||
|
if not self._available:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self._save_memory_sync,
|
||||||
|
user_id, memory, topics, importance, expires_at, source,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to save memory")
|
||||||
|
|
||||||
|
def _save_memory_sync(self, user_id, memory, topics, importance, expires_at, source):
|
||||||
|
conn = self._connect()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""INSERT INTO UserMemory (UserId, Memory, Topics, Importance, ExpiresAt, Source)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)""",
|
||||||
|
user_id,
|
||||||
|
memory[:500],
|
||||||
|
topics[:200],
|
||||||
|
importance[:10],
|
||||||
|
expires_at,
|
||||||
|
source[:20],
|
||||||
|
)
|
||||||
|
cursor.close()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
async def get_recent_memories(self, user_id: int, limit: int = 10) -> list[dict]:
|
||||||
|
"""Get the N most recent non-expired memories for a user."""
|
||||||
|
if not self._available:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(self._get_recent_memories_sync, user_id, limit)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get recent memories")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _get_recent_memories_sync(self, user_id, limit) -> list[dict]:
|
||||||
|
conn = self._connect()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""SELECT TOP (?) Memory, Topics, Importance, CreatedAt
|
||||||
|
FROM UserMemory
|
||||||
|
WHERE UserId = ? AND ExpiresAt > SYSUTCDATETIME()
|
||||||
|
ORDER BY CreatedAt DESC""",
|
||||||
|
limit, user_id,
|
||||||
|
)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"memory": row[0],
|
||||||
|
"topics": row[1],
|
||||||
|
"importance": row[2],
|
||||||
|
"created_at": row[3],
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
async def get_memories_by_topics(self, user_id: int, topic_keywords: list[str], limit: int = 10) -> list[dict]:
|
||||||
|
"""Get non-expired memories matching any of the given topic keywords via LIKE."""
|
||||||
|
if not self._available:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
self._get_memories_by_topics_sync, user_id, topic_keywords, limit,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get memories by topics")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _get_memories_by_topics_sync(self, user_id, topic_keywords, limit) -> list[dict]:
|
||||||
|
conn = self._connect()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
if not topic_keywords:
|
||||||
|
cursor.close()
|
||||||
|
return []
|
||||||
|
# Build OR conditions for each keyword
|
||||||
|
conditions = " OR ".join(["Topics LIKE ?" for _ in topic_keywords])
|
||||||
|
params = [limit, user_id] + [f"%{kw}%" for kw in topic_keywords]
|
||||||
|
cursor.execute(
|
||||||
|
f"""SELECT TOP (?) Memory, Topics, Importance, CreatedAt
|
||||||
|
FROM UserMemory
|
||||||
|
WHERE UserId = ? AND ExpiresAt > SYSUTCDATETIME()
|
||||||
|
AND ({conditions})
|
||||||
|
ORDER BY
|
||||||
|
CASE Importance
|
||||||
|
WHEN 'high' THEN 1
|
||||||
|
WHEN 'medium' THEN 2
|
||||||
|
WHEN 'low' THEN 3
|
||||||
|
ELSE 4
|
||||||
|
END,
|
||||||
|
CreatedAt DESC""",
|
||||||
|
*params,
|
||||||
|
)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"memory": row[0],
|
||||||
|
"topics": row[1],
|
||||||
|
"importance": row[2],
|
||||||
|
"created_at": row[3],
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
async def prune_expired_memories(self) -> int:
|
||||||
|
"""Delete all expired memories. Returns count deleted."""
|
||||||
|
if not self._available:
|
||||||
|
return 0
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(self._prune_expired_memories_sync)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to prune expired memories")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _prune_expired_memories_sync(self) -> int:
|
||||||
|
conn = self._connect()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM UserMemory WHERE ExpiresAt < SYSUTCDATETIME()")
|
||||||
|
count = cursor.rowcount
|
||||||
|
cursor.close()
|
||||||
|
return count
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
async def prune_excess_memories(self, user_id: int, cap: int = 50) -> int:
|
||||||
|
"""Delete excess memories for a user beyond the cap, keeping high importance and newest first.
|
||||||
|
Returns count deleted."""
|
||||||
|
if not self._available:
|
||||||
|
return 0
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(self._prune_excess_memories_sync, user_id, cap)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to prune excess memories")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _prune_excess_memories_sync(self, user_id, cap) -> int:
|
||||||
|
conn = self._connect()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""DELETE FROM UserMemory
|
||||||
|
WHERE Id IN (
|
||||||
|
SELECT Id FROM (
|
||||||
|
SELECT Id, ROW_NUMBER() OVER (
|
||||||
|
ORDER BY
|
||||||
|
CASE Importance
|
||||||
|
WHEN 'high' THEN 1
|
||||||
|
WHEN 'medium' THEN 2
|
||||||
|
WHEN 'low' THEN 3
|
||||||
|
ELSE 4
|
||||||
|
END,
|
||||||
|
CreatedAt DESC
|
||||||
|
) AS rn
|
||||||
|
FROM UserMemory
|
||||||
|
WHERE UserId = ?
|
||||||
|
) ranked
|
||||||
|
WHERE rn > ?
|
||||||
|
)""",
|
||||||
|
user_id, cap,
|
||||||
|
)
|
||||||
|
count = cursor.rowcount
|
||||||
|
cursor.close()
|
||||||
|
return count
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""No persistent connection to close (connections are per-operation)."""
|
"""No persistent connection to close (connections are per-operation)."""
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user