clawdbot-workspace/memory-retrieval.py
2026-01-28 23:00:58 -05:00

486 lines
15 KiB
Python

#!/usr/bin/env python3
"""
Memory retrieval functions for Clawdbot.
Supports FTS5 keyword search, with guild scoping and recency bias.
Vector search requires the vec0 extension loaded at runtime.
"""
import sqlite3
import os
import stat
import json
import sys
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
# Location of the SQLite memory store and of the directory that holds it.
DB_PATH = os.path.expanduser("~/.clawdbot/memory/main.sqlite")
DB_DIR = os.path.dirname(DB_PATH)

# Security: required permission bits, spelled with the stat constants.
SECURE_FILE_MODE = stat.S_IRUSR | stat.S_IWUSR  # 0o600: owner read/write only
SECURE_DIR_MODE = stat.S_IRWXU  # 0o700: owner read/write/execute only
# File-name endings treated as SQLite data: the database itself plus the
# sidecar files SQLite may keep next to it (write-ahead log, shared-memory
# index, rollback journal). These hold the same data as the database.
_SQLITE_FILE_SUFFIXES = (".sqlite", ".sqlite-wal", ".sqlite-shm", ".sqlite-journal")


def _fix_mode(path: str, wanted: int, label: str, warn: bool) -> Optional[str]:
    """Chmod *path* to *wanted* if its mode differs; return the fix message, else None."""
    current = stat.S_IMODE(os.stat(path).st_mode)
    if current == wanted:
        return None
    os.chmod(path, wanted)
    msg = f"[SECURITY] Fixed {label}: {path} ({oct(current)} -> {oct(wanted)})"
    if warn:
        print(msg, file=sys.stderr)
    return msg


def ensure_secure_permissions(warn: bool = True) -> List[str]:
    """
    Check and auto-fix permissions on the database directory and its files.

    Self-healing: any insecure mode found is corrected with os.chmod.
    The directory is forced to SECURE_DIR_MODE; the main database and every
    regular file in the directory whose name ends in one of
    _SQLITE_FILE_SUFFIXES is forced to SECURE_FILE_MODE.

    Args:
        warn: When True, each fix is also printed to stderr.

    Returns:
        List of human-readable messages, one per fix applied (empty if
        nothing needed fixing or the directory does not exist).
    """
    fixes: List[str] = []

    # Without the directory there is nothing to inspect (the database lives in it).
    if not os.path.exists(DB_DIR):
        return fixes

    # The directory goes first: its mode may be what prevents us from
    # seeing the files inside it.
    msg = _fix_mode(DB_DIR, SECURE_DIR_MODE, "directory permissions", warn)
    if msg:
        fixes.append(msg)

    # Main database file gets its own, more specific message.
    if os.path.exists(DB_PATH):
        msg = _fix_mode(DB_PATH, SECURE_FILE_MODE, "database permissions", warn)
        if msg:
            fixes.append(msg)

    # Every other SQLite file (and sidecar) in the directory.
    for filename in sorted(os.listdir(DB_DIR)):
        if not filename.endswith(_SQLITE_FILE_SUFFIXES):
            continue
        filepath = os.path.join(DB_DIR, filename)
        # Only regular files: a 0o600 mode would make a directory unusable.
        if not os.path.isfile(filepath):
            continue
        try:
            msg = _fix_mode(filepath, SECURE_FILE_MODE, "permissions", warn)
        except FileNotFoundError:
            # Journal/WAL files are transient; one may vanish between
            # listdir() and stat(). Nothing left to secure in that case.
            continue
        if msg:
            fixes.append(msg)

    return fixes
def create_db_if_needed() -> bool:
    """
    Create the database directory and file with secure permissions if missing.

    The database file is first created through os.open() with
    SECURE_FILE_MODE, so it never exists with wider permissions; letting
    sqlite3.connect() create it would leave it at the process's default
    permissions until the follow-up chmod ran.

    Returns:
        True if the directory or the database was created, False if both
        already existed.
    """
    created = False

    # Create directory with secure permissions
    if not os.path.exists(DB_DIR):
        # exist_ok: another process may create it between the check and here.
        os.makedirs(DB_DIR, mode=SECURE_DIR_MODE, exist_ok=True)
        # makedirs() applies the umask to `mode`; pin the exact bits.
        os.chmod(DB_DIR, SECURE_DIR_MODE)
        print(f"[SECURITY] Created directory with secure permissions: {DB_DIR}", file=sys.stderr)
        created = True

    # Create database with secure permissions
    if not os.path.exists(DB_PATH):
        # Pre-create the (empty) file already restricted to the owner.
        fd = os.open(DB_PATH, os.O_WRONLY | os.O_CREAT, SECURE_FILE_MODE)
        os.close(fd)
        # Open and close it once through sqlite3, as the original creation did.
        conn = sqlite3.connect(DB_PATH)
        conn.close()
        # The umask can only remove bits from the mode given to os.open();
        # chmod guarantees the final mode is exactly SECURE_FILE_MODE.
        os.chmod(DB_PATH, SECURE_FILE_MODE)
        print(f"[SECURITY] Created database with secure permissions: {DB_PATH}", file=sys.stderr)
        created = True

    return created
def get_db():
    """
    Open the memory database and return the connection.

    Permissions are re-checked (and repaired if needed) on every call,
    before the file is looked up. The returned connection uses
    sqlite3.Row as its row factory, so columns can be read by name.

    Raises:
        FileNotFoundError: if no file exists at DB_PATH.
    """
    # Self-healing step always runs first.
    ensure_secure_permissions(warn=True)

    if os.path.exists(DB_PATH):
        connection = sqlite3.connect(DB_PATH)
        connection.row_factory = sqlite3.Row
        return connection

    raise FileNotFoundError(f"Memory database not found: {DB_PATH}")
def escape_fts5_query(query: str) -> Optional[str]:
    """
    Escape a plain-text query for safe use in an FTS5 MATCH expression.

    FTS5 gives meaning to double quotes, to the keywords AND / OR / NOT /
    NEAR and to characters such as ``*``, and unquoted terms may not contain
    punctuation such as hyphens or apostrophes. To sidestep all of that,
    the query is split on whitespace and every word is emitted as its own
    double-quoted FTS5 string, with embedded double quotes doubled.

    Args:
        query: Raw user text.

    Returns:
        The quoted words joined by single spaces, or None when the query is
        empty or whitespace-only (callers treat None as "no search").
    """
    if not query or not query.strip():
        return None

    # '"' inside an FTS5 string is escaped by doubling it.
    return " ".join('"' + word.replace('"', '""') + '"' for word in query.split())
def search_memories(
    query: str,
    guild_id: Optional[str] = None,
    user_id: Optional[str] = None,
    memory_type: Optional[str] = None,
    days_back: int = 90,
    limit: int = 10
) -> List[Dict[str, Any]]:
    """
    Full-text search over non-superseded memories, with optional filters.

    Every returned memory has its last_accessed stamp refreshed and its
    access_count incremented; the access_count in the result is the value
    read before that increment.

    Args:
        query: Plain-text search string; it is escaped before being handed
            to FTS5, so hyphens, quotes and operators are matched literally.
        guild_id: When given, keep memories of that guild plus memories
            with no guild at all (None = no guild filtering).
        user_id: When given, keep only that user's memories.
        memory_type: When given, keep only that type (fact, preference,
            event, relationship, project, person, security).
        days_back: Keep only memories created within the last N days
            (0 = all time).
        limit: Max results to return.

    Returns:
        List of dicts (id, content, summary, memory_type, guild_id,
        user_id, confidence, created_at, access_count, fts_score), ordered
        by bm25 score, then confidence (high first), then creation time
        (newest first). Empty list for an empty/whitespace query.
    """
    match_expr = escape_fts5_query(query)
    if match_expr is None:
        return []  # nothing to search for

    conn = get_db()
    try:
        # (SQL fragment, bound value) pairs; the MATCH term always leads.
        filters = [("memories_fts MATCH ?", match_expr)]
        if guild_id:
            filters.append(("(m.guild_id = ? OR m.guild_id IS NULL)", guild_id))
        if user_id:
            filters.append(("m.user_id = ?", user_id))
        if memory_type:
            filters.append(("m.memory_type = ?", memory_type))
        if days_back > 0:
            oldest_allowed = int((datetime.now() - timedelta(days=days_back)).timestamp())
            filters.append(("m.created_at > ?", oldest_allowed))

        where_clause = " AND ".join(fragment for fragment, _ in filters)
        bindings = [value for _, value in filters]
        bindings.append(limit)  # feeds the trailing LIMIT ?

        sql = f"""
            SELECT
            m.id,
            m.content,
            m.summary,
            m.memory_type,
            m.guild_id,
            m.user_id,
            m.confidence,
            m.created_at,
            m.access_count,
            bm25(memories_fts) as fts_score
            FROM memories m
            JOIN memories_fts fts ON m.id = fts.rowid
            WHERE {where_clause}
            AND m.superseded_by IS NULL
            ORDER BY
            bm25(memories_fts),
            m.confidence DESC,
            m.created_at DESC
            LIMIT ?
        """

        fields = (
            "id", "content", "summary", "memory_type", "guild_id",
            "user_id", "confidence", "created_at", "access_count", "fts_score",
        )
        hits = [
            {field: row[field] for field in fields}
            for row in conn.execute(sql, bindings)
        ]

        # Record the access on every memory that was returned.
        if hits:
            hit_ids = [hit["id"] for hit in hits]
            placeholders = ",".join("?" for _ in hit_ids)
            conn.execute(f"""
                UPDATE memories
                SET last_accessed = ?, access_count = access_count + 1
                WHERE id IN ({placeholders})
            """, [int(datetime.now().timestamp()), *hit_ids])
            conn.commit()

        return hits
    finally:
        conn.close()
def get_recent_memories(
    guild_id: Optional[str] = None,
    limit: int = 10,
    memory_type: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Return the newest non-superseded memories, newest first.

    Handy for loading context when there is no specific query.

    Args:
        guild_id: When given, keep memories of that guild plus memories
            with no guild at all.
        limit: Max results to return.
        memory_type: When given, keep only that memory type.

    Returns:
        List of dicts with id, content, summary, memory_type, guild_id,
        user_id, confidence, created_at and access_count.
    """
    conn = get_db()
    try:
        clauses = ["superseded_by IS NULL"]
        bindings = []

        # Optional filters, in the order their placeholders appear in the SQL.
        optional_filters = (
            ("(guild_id = ? OR guild_id IS NULL)", guild_id),
            ("memory_type = ?", memory_type),
        )
        for clause, value in optional_filters:
            if value:
                clauses.append(clause)
                bindings.append(value)

        where_clause = " AND ".join(clauses)
        bindings.append(limit)  # feeds the trailing LIMIT ?

        sql = f"""
            SELECT
            id, content, summary, memory_type, guild_id,
            user_id, confidence, created_at, access_count
            FROM memories
            WHERE {where_clause}
            ORDER BY created_at DESC
            LIMIT ?
        """
        return [dict(row) for row in conn.execute(sql, bindings)]
    finally:
        conn.close()
def add_memory(
    content: str,
    memory_type: str = "fact",
    guild_id: Optional[str] = None,
    channel_id: Optional[str] = None,
    user_id: Optional[str] = None,
    summary: Optional[str] = None,
    source: str = "explicit",
    confidence: float = 1.0
) -> int:
    """
    Insert a new memory row, stamped with the current time.

    Args:
        content: Memory text; must contain something other than whitespace.
        memory_type: Type label stored with the memory.
        guild_id: Guild the memory belongs to, if any.
        channel_id: Channel the memory came from, if any.
        user_id: User the memory is about, if any.
        summary: Optional short summary.
        source: Origin label stored with the memory.
        confidence: Value in the closed range 0.0 to 1.0.

    Returns:
        ID of the created memory

    Raises:
        ValueError: if content is empty/blank or confidence is out of range.
    """
    # Validate before touching the database.
    if not (content and content.strip()):
        raise ValueError("Content cannot be empty")
    if not (0.0 <= confidence <= 1.0):
        raise ValueError("Confidence must be between 0.0 and 1.0")

    conn = get_db()
    try:
        created_at = int(datetime.now().timestamp())
        row_values = (
            content, memory_type, guild_id, channel_id, user_id,
            summary, source, confidence, created_at,
        )
        inserted = conn.execute("""
            INSERT INTO memories (
            content, memory_type, guild_id, channel_id, user_id,
            summary, source, confidence, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, row_values)
        new_id = inserted.lastrowid
        conn.commit()
        return new_id
    finally:
        conn.close()
def supersede_memory(old_id: int, new_content: str, reason: str = "updated") -> int:
    """
    Replace a memory with a new version while keeping the old row.

    The new row copies memory_type, guild_id, channel_id, user_id and
    source from the old one, gets confidence 1.0 and the summary
    "Updated: <reason>"; the old row's superseded_by is then pointed at it.
    Both writes are committed together and rolled back on any error.

    Args:
        old_id: ID of the memory being replaced.
        new_content: Text of the replacement; must not be blank.
        reason: Short explanation stored in the new memory's summary.

    Returns:
        ID of the new memory

    Raises:
        ValueError: if new_content is empty/blank or old_id does not exist.
    """
    if not (new_content and new_content.strip()):
        raise ValueError("New content cannot be empty")

    conn = get_db()
    try:
        # Metadata to carry over from the memory being replaced.
        previous = conn.execute("""
            SELECT memory_type, guild_id, channel_id, user_id, source
            FROM memories WHERE id = ?
        """, (old_id,)).fetchone()
        if previous is None:
            raise ValueError(f"Memory {old_id} not found")

        replacement_values = (
            new_content,
            previous["memory_type"],
            previous["guild_id"],
            previous["channel_id"],
            previous["user_id"],
            f"Updated: {reason}",
            previous["source"],
            int(datetime.now().timestamp()),
        )
        inserted = conn.execute("""
            INSERT INTO memories (
            content, memory_type, guild_id, channel_id, user_id,
            summary, source, confidence, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, 1.0, ?)
        """, replacement_values)
        replacement_id = inserted.lastrowid

        # Link the old row to its replacement within the same transaction.
        conn.execute(
            "UPDATE memories SET superseded_by = ? WHERE id = ?",
            (replacement_id, old_id),
        )
        conn.commit()
        return replacement_id
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
def get_memory_stats() -> Dict[str, Any]:
    """
    Collect counts describing the memory database.

    Returns:
        Dict with:
        - "total_active": number of memories not superseded
        - "total_superseded": number of superseded memories
        - "by_type": active-memory count per memory_type
        - "by_guild": active-memory count per guild, with memories that
          have no guild reported under the key 'global'
    """
    conn = get_db()
    try:
        def single_value(sql: str):
            """Run a one-row, one-column query and return that value."""
            return conn.execute(sql).fetchone()[0]

        def pairs_to_dict(sql: str) -> Dict[Any, Any]:
            """Run a two-column query and map first column -> second column."""
            return {key: count for key, count in conn.execute(sql)}

        # Queries run in the same order as the keys below.
        return {
            "total_active": single_value(
                "SELECT COUNT(*) FROM memories WHERE superseded_by IS NULL"
            ),
            "total_superseded": single_value(
                "SELECT COUNT(*) FROM memories WHERE superseded_by IS NOT NULL"
            ),
            "by_type": pairs_to_dict("""
                SELECT memory_type, COUNT(*)
                FROM memories
                WHERE superseded_by IS NULL
                GROUP BY memory_type
            """),
            "by_guild": pairs_to_dict("""
                SELECT
                COALESCE(guild_id, 'global') as guild,
                COUNT(*)
                FROM memories
                WHERE superseded_by IS NULL
                GROUP BY guild_id
            """),
        }
    finally:
        conn.close()
# CLI interface for testing.
# Results are written to stdout; usage text and error messages go to stderr
# (like the [SECURITY] notices elsewhere in this file) so that stdout stays
# clean for piping. Every failure path exits with status 1.
if __name__ == "__main__":

    def get_cli_arg(flag: str, default=None):
        """
        Return the value following *flag* in sys.argv, or *default* if the
        flag is absent. Exits with status 1 if the flag is the last argument.
        """
        if flag not in sys.argv:
            return default
        idx = sys.argv.index(flag)
        if idx + 1 >= len(sys.argv):
            print(f"Error: {flag} requires a value", file=sys.stderr)
            sys.exit(1)
        return sys.argv[idx + 1]

    def safe_truncate(text: Optional[str], length: int = 100) -> str:
        """
        Collapse *text* onto one line and cut it to *length* characters,
        appending "..." when something was cut. None becomes "(empty)".
        """
        if text is None:
            return "(empty)"
        text = text.replace('\n', ' ')  # Single line
        if len(text) <= length:
            return text
        return text[:length] + "..."

    if len(sys.argv) < 2:
        print("Usage:", file=sys.stderr)
        print(" python memory-retrieval.py search <query> [--guild <id>]", file=sys.stderr)
        print(" python memory-retrieval.py recent [--guild <id>] [--limit <n>]", file=sys.stderr)
        print(" python memory-retrieval.py stats", file=sys.stderr)
        print(" python memory-retrieval.py add <content> [--type <type>] [--guild <id>]", file=sys.stderr)
        sys.exit(1)

    cmd = sys.argv[1]

    if cmd == "search" and len(sys.argv) >= 3:
        # search <query> [--guild <id>]
        query = sys.argv[2]
        guild_id = get_cli_arg("--guild")
        results = search_memories(query, guild_id=guild_id)
        print(f"Found {len(results)} memories:\n")
        for r in results:
            print(f"[{r['id']}] ({r['memory_type']}) {safe_truncate(r['content'])}")
            print(f" Score: {r['fts_score']:.4f}, Confidence: {r['confidence']}")
            print()
    elif cmd == "recent":
        # recent [--guild <id>] [--limit <n>]
        guild_id = get_cli_arg("--guild")
        limit_str = get_cli_arg("--limit", "10")
        try:
            limit = int(limit_str)
        except ValueError:
            print(f"Error: --limit must be a number, got '{limit_str}'", file=sys.stderr)
            sys.exit(1)
        results = get_recent_memories(guild_id=guild_id, limit=limit)
        print(f"Recent {len(results)} memories:\n")
        for r in results:
            print(f"[{r['id']}] ({r['memory_type']}) {safe_truncate(r['content'])}")
            print()
    elif cmd == "stats":
        stats = get_memory_stats()
        print(json.dumps(stats, indent=2))
    elif cmd == "add" and len(sys.argv) >= 3:
        # add <content> [--type <type>] [--guild <id>]
        content = sys.argv[2]
        memory_type = get_cli_arg("--type", "fact")
        guild_id = get_cli_arg("--guild")
        memory_id = add_memory(content, memory_type=memory_type, guild_id=guild_id)
        print(f"Added memory with ID: {memory_id}")
    else:
        # Also reached by "search"/"add" given without their positional argument.
        print("Unknown command or missing arguments", file=sys.stderr)
        sys.exit(1)