486 lines
15 KiB
Python
486 lines
15 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Memory retrieval functions for Clawdbot.
|
|
Supports FTS5 keyword search, with guild scoping and recency bias.
|
|
Vector search requires the vec0 extension loaded at runtime.
|
|
"""
|
|
|
|
import sqlite3
|
|
import os
|
|
import stat
|
|
import json
|
|
import sys
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, List, Dict, Any
|
|
|
|
# Location of the SQLite memory store; "~" is expanded to the current user's home.
DB_PATH = os.path.expanduser("~/.clawdbot/memory/main.sqlite")
# Directory that contains the database file (derived from DB_PATH).
DB_DIR = os.path.dirname(DB_PATH)

# Security: Required permissions
SECURE_FILE_MODE = 0o600 # Owner read/write only
SECURE_DIR_MODE = 0o700 # Owner read/write/execute only
|
|
|
|
|
|
def _enforce_mode(path: str, wanted: int, label: str, fixes: List[str], warn: bool) -> None:
    """Chmod *path* to *wanted* when its mode differs, recording the fix in *fixes*."""
    try:
        current = stat.S_IMODE(os.stat(path).st_mode)
    except OSError:
        # Missing (or un-stat-able) path: nothing to repair.
        return

    if current == wanted:
        return

    try:
        os.chmod(path, wanted)
    except FileNotFoundError:
        # SQLite sidecar files are transient and may vanish between stat and chmod.
        return

    msg = f"[SECURITY] Fixed {label}: {path} ({oct(current)} -> {oct(wanted)})"
    fixes.append(msg)
    if warn:
        print(msg, file=sys.stderr)


def ensure_secure_permissions(warn: bool = True) -> List[str]:
    """
    Check and auto-fix permissions on the database directory and its files.

    The directory is forced to SECURE_DIR_MODE; the main database and every
    other SQLite file in the directory are forced to SECURE_FILE_MODE. SQLite
    sidecar files (``-wal``, ``-shm``, ``-journal``) are covered as well,
    because they hold database content too.

    Args:
        warn: When True, each applied fix is also printed to stderr.

    Returns:
        List of human-readable messages, one per fix applied (empty when
        everything was already secure or nothing exists yet).

    Raises:
        OSError: If a chmod fails for a reason other than the file having
            disappeared (e.g. the current user does not own the file).
    """
    # Suffixes of the database files themselves plus SQLite's companion files.
    sqlite_suffixes = (".sqlite", ".sqlite-wal", ".sqlite-shm", ".sqlite-journal")
    fixes: List[str] = []

    # Check directory permissions
    _enforce_mode(DB_DIR, SECURE_DIR_MODE, "directory permissions", fixes, warn)

    # Check database file permissions
    _enforce_mode(DB_PATH, SECURE_FILE_MODE, "database permissions", fixes, warn)

    # Check any other sqlite files (and sidecars) in the directory
    if os.path.exists(DB_DIR):
        for filename in os.listdir(DB_DIR):
            if filename.endswith(sqlite_suffixes):
                _enforce_mode(
                    os.path.join(DB_DIR, filename),
                    SECURE_FILE_MODE,
                    "permissions",
                    fixes,
                    warn,
                )

    return fixes
|
|
|
|
|
|
def create_db_if_needed() -> bool:
    """
    Create database directory and file with secure permissions if they don't exist.

    The database file is created atomically with SECURE_FILE_MODE (instead of
    being created with umask-default permissions and tightened afterwards), so
    there is no window in which other users could open it.

    Returns:
        True if the directory or the database was created, False if both
        already existed.

    Raises:
        OSError: If the directory or file cannot be created or chmod'ed.
        sqlite3.Error: If SQLite cannot open the freshly created file.
    """
    created = False

    # Create directory with secure permissions
    if not os.path.exists(DB_DIR):
        # exist_ok tolerates another process creating it in the meantime.
        os.makedirs(DB_DIR, mode=SECURE_DIR_MODE, exist_ok=True)
        # makedirs' mode is filtered by the umask; set the exact mode explicitly.
        os.chmod(DB_DIR, SECURE_DIR_MODE)
        print(f"[SECURITY] Created directory with secure permissions: {DB_DIR}", file=sys.stderr)
        created = True

    # Create database with secure permissions
    try:
        # O_EXCL makes creation atomic: we either create the file with the
        # secure mode ourselves or learn that it already exists.
        fd = os.open(DB_PATH, os.O_WRONLY | os.O_CREAT | os.O_EXCL, SECURE_FILE_MODE)
    except FileExistsError:
        return created
    os.close(fd)

    # The creation mode is also subject to the umask; enforce the exact mode.
    os.chmod(DB_PATH, SECURE_FILE_MODE)

    # An empty file is a valid empty SQLite database; opening it confirms
    # that SQLite can use the path.
    conn = sqlite3.connect(DB_PATH)
    conn.close()

    print(f"[SECURITY] Created database with secure permissions: {DB_PATH}", file=sys.stderr)
    return True
|
|
|
|
|
|
def get_db():
    """
    Open a connection to the memory database.

    Permissions are re-validated (and repaired, with warnings on stderr)
    before every connection. Rows are returned as ``sqlite3.Row`` objects.

    Raises:
        FileNotFoundError: If the database file does not exist.
    """
    # Self-healing step runs first, on every call.
    ensure_secure_permissions(warn=True)

    if os.path.exists(DB_PATH):
        connection = sqlite3.connect(DB_PATH)
        connection.row_factory = sqlite3.Row
        return connection

    raise FileNotFoundError(f"Memory database not found: {DB_PATH}")
|
|
|
|
def escape_fts5_query(query: str) -> Optional[str]:
    """
    Escape a query string for safe use in FTS5 MATCH.

    FTS5 special characters that need handling:
    - Quotes (" and ') have special meaning
    - Hyphens (-) mean NOT operator
    - Other operators: AND, OR, NOT, NEAR, *

    Strategy: split on whitespace and wrap each word in double quotes so it
    is treated as a literal phrase; double quotes inside a word are doubled.

    Args:
        query: Raw user-supplied search text.

    Returns:
        The quoted terms joined by single spaces, or None when the query is
        empty or whitespace-only.
    """
    if not query or not query.strip():
        return None

    # Quoting every term neutralises hyphens, apostrophes and operator words.
    return " ".join(
        '"' + term.replace('"', '""') + '"'
        for term in query.split()
    )
|
|
|
|
|
|
def search_memories(
    query: str,
    guild_id: Optional[str] = None,
    user_id: Optional[str] = None,
    memory_type: Optional[str] = None,
    days_back: int = 90,
    limit: int = 10
) -> List[Dict[str, Any]]:
    """
    Run an FTS5 search over non-superseded memories with optional filters.

    Args:
        query: Plain-text search query; it is escaped so special characters
            such as hyphens and quotes are matched literally.
        guild_id: Restrict to this guild; rows with a NULL guild_id are
            still included. None = all guilds.
        user_id: Restrict to this user.
        memory_type: Restrict to this type (fact, preference, event,
            relationship, project, person, security).
        days_back: Only consider memories created in the last N days
            (0 = all time).
        limit: Maximum number of rows to return.

    Returns:
        Matching memories as dicts (including ``fts_score``), ordered by
        bm25 score, then confidence descending, then newest first. An empty
        list is returned for an empty/whitespace query.

    Side effects:
        Every returned memory gets ``last_accessed`` set to now and its
        ``access_count`` incremented; the returned ``access_count`` is the
        value read before that update.
    """
    match_expr = escape_fts5_query(query)
    if match_expr is None:
        return []  # Nothing to search for

    conn = get_db()
    try:
        # The MATCH clause is always present; the rest are opt-in filters.
        clauses = ["memories_fts MATCH ?"]
        bindings = [match_expr]

        optional_filters = (
            ("(m.guild_id = ? OR m.guild_id IS NULL)", guild_id),
            ("m.user_id = ?", user_id),
            ("m.memory_type = ?", memory_type),
        )
        for clause, value in optional_filters:
            if value:
                clauses.append(clause)
                bindings.append(value)

        if days_back > 0:
            oldest_allowed = datetime.now() - timedelta(days=days_back)
            clauses.append("m.created_at > ?")
            bindings.append(int(oldest_allowed.timestamp()))

        where_clause = " AND ".join(clauses)
        bindings.append(limit)  # bound to the trailing LIMIT ?

        sql = f"""
            SELECT
                m.id,
                m.content,
                m.summary,
                m.memory_type,
                m.guild_id,
                m.user_id,
                m.confidence,
                m.created_at,
                m.access_count,
                bm25(memories_fts) as fts_score
            FROM memories m
            JOIN memories_fts fts ON m.id = fts.rowid
            WHERE {where_clause}
            AND m.superseded_by IS NULL
            ORDER BY
                bm25(memories_fts),
                m.confidence DESC,
                m.created_at DESC
            LIMIT ?
        """

        columns = (
            "id",
            "content",
            "summary",
            "memory_type",
            "guild_id",
            "user_id",
            "confidence",
            "created_at",
            "access_count",
            "fts_score",
        )
        matches = [
            {column: row[column] for column in columns}
            for row in conn.execute(sql, bindings)
        ]

        # Update access counts
        if matches:
            matched_ids = [match["id"] for match in matches]
            placeholders = ",".join("?" for _ in matched_ids)
            touched_at = int(datetime.now().timestamp())
            conn.execute(f"""
                UPDATE memories
                SET last_accessed = ?, access_count = access_count + 1
                WHERE id IN ({placeholders})
            """, [touched_at, *matched_ids])
            conn.commit()

        return matches
    finally:
        conn.close()
|
|
|
|
|
|
def get_recent_memories(
    guild_id: Optional[str] = None,
    limit: int = 10,
    memory_type: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Return the newest non-superseded memories.

    Handy for loading context when there is no specific query.

    Args:
        guild_id: Restrict to this guild; rows with a NULL guild_id are
            still included. None = all guilds.
        limit: Maximum number of rows to return.
        memory_type: Restrict to this memory type.

    Returns:
        Memories as dicts, newest first.
    """
    conn = get_db()
    try:
        clauses = ["superseded_by IS NULL"]
        bindings = []

        optional_filters = (
            ("(guild_id = ? OR guild_id IS NULL)", guild_id),
            ("memory_type = ?", memory_type),
        )
        for clause, value in optional_filters:
            if value:
                clauses.append(clause)
                bindings.append(value)

        where_clause = " AND ".join(clauses)
        bindings.append(limit)  # bound to the trailing LIMIT ?

        sql = f"""
            SELECT
                id, content, summary, memory_type, guild_id,
                user_id, confidence, created_at, access_count
            FROM memories
            WHERE {where_clause}
            ORDER BY created_at DESC
            LIMIT ?
        """

        return [dict(row) for row in conn.execute(sql, bindings)]
    finally:
        conn.close()
|
|
|
|
|
|
def add_memory(
    content: str,
    memory_type: str = "fact",
    guild_id: Optional[str] = None,
    channel_id: Optional[str] = None,
    user_id: Optional[str] = None,
    summary: Optional[str] = None,
    source: str = "explicit",
    confidence: float = 1.0
) -> int:
    """
    Insert a new memory row and commit it.

    Args:
        content: Memory text; must contain non-whitespace characters.
        memory_type: Category stored with the memory.
        guild_id: Guild the memory belongs to, if any.
        channel_id: Channel the memory belongs to, if any.
        user_id: User the memory belongs to, if any.
        summary: Optional short summary.
        source: Origin label stored with the memory.
        confidence: Value in the inclusive range 0.0-1.0.

    Returns:
        ID of the created memory.

    Raises:
        ValueError: If content is empty/whitespace or confidence is outside
            the 0.0-1.0 range.
    """
    if not (content and content.strip()):
        raise ValueError("Content cannot be empty")
    # Chained comparison is kept on purpose: it also rejects NaN.
    if not 0.0 <= confidence <= 1.0:
        raise ValueError("Confidence must be between 0.0 and 1.0")

    conn = get_db()
    try:
        row_values = (
            content, memory_type, guild_id, channel_id, user_id,
            summary, source, confidence, int(datetime.now().timestamp())
        )
        inserted = conn.execute("""
            INSERT INTO memories (
                content, memory_type, guild_id, channel_id, user_id,
                summary, source, confidence, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, row_values)

        new_id = inserted.lastrowid
        conn.commit()
        return new_id
    finally:
        conn.close()
|
|
|
|
|
|
def supersede_memory(old_id: int, new_content: str, reason: str = "updated") -> int:
    """
    Replace a memory with a new version while keeping the old row.

    The new row copies memory_type, guild_id, channel_id, user_id and source
    from the old one, gets confidence 1.0 and the summary
    ``"Updated: <reason>"``. The old row's ``superseded_by`` is then pointed
    at the new row. Both statements are committed together and rolled back
    on any error.

    Args:
        old_id: ID of the memory being replaced.
        new_content: Text of the replacement; must not be empty/whitespace.
        reason: Short explanation stored in the new row's summary.

    Returns:
        ID of the new memory.

    Raises:
        ValueError: If new_content is empty or old_id does not exist.
    """
    if not (new_content and new_content.strip()):
        raise ValueError("New content cannot be empty")

    conn = get_db()
    try:
        # Fetch the metadata the replacement row inherits.
        previous = conn.execute("""
            SELECT memory_type, guild_id, channel_id, user_id, source
            FROM memories WHERE id = ?
        """, (old_id,)).fetchone()

        if not previous:
            raise ValueError(f"Memory {old_id} not found")

        # Insert the replacement; it is committed together with the UPDATE below.
        replacement_values = (
            new_content,
            previous["memory_type"],
            previous["guild_id"],
            previous["channel_id"],
            previous["user_id"],
            f"Updated: {reason}",
            previous["source"],
            int(datetime.now().timestamp()),
        )
        inserted = conn.execute("""
            INSERT INTO memories (
                content, memory_type, guild_id, channel_id, user_id,
                summary, source, confidence, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, 1.0, ?)
        """, replacement_values)
        replacement_id = inserted.lastrowid

        # Mark old as superseded
        conn.execute(
            "UPDATE memories SET superseded_by = ? WHERE id = ?",
            (replacement_id, old_id),
        )

        conn.commit()
        return replacement_id
    except Exception:
        # Undo the partial write, then let the caller see the original error.
        conn.rollback()
        raise
    finally:
        conn.close()
|
|
|
|
|
|
def get_memory_stats() -> Dict[str, Any]:
    """
    Collect summary counts from the memory database.

    Returns:
        Dict with the keys ``total_active`` and ``total_superseded``
        (row counts), ``by_type`` (active count per memory_type) and
        ``by_guild`` (active count per guild, with NULL guilds reported
        under ``'global'``).
    """
    conn = get_db()
    try:
        def scalar(sql: str):
            """Return the first column of the first row of *sql*."""
            return conn.execute(sql).fetchone()[0]

        # Total counts
        active_total = scalar("SELECT COUNT(*) FROM memories WHERE superseded_by IS NULL")
        superseded_total = scalar("SELECT COUNT(*) FROM memories WHERE superseded_by IS NOT NULL")

        # By type
        per_type = {
            row[0]: row[1]
            for row in conn.execute("""
                SELECT memory_type, COUNT(*)
                FROM memories
                WHERE superseded_by IS NULL
                GROUP BY memory_type
            """)
        }

        # By guild
        per_guild = {
            row[0]: row[1]
            for row in conn.execute("""
                SELECT
                    COALESCE(guild_id, 'global') as guild,
                    COUNT(*)
                FROM memories
                WHERE superseded_by IS NULL
                GROUP BY guild_id
            """)
        }

        # Key order matches the order in which the figures were gathered.
        return {
            "total_active": active_total,
            "total_superseded": superseded_total,
            "by_type": per_type,
            "by_guild": per_guild,
        }
    finally:
        conn.close()
|
|
|
|
|
|
# CLI interface for testing
|
|
if __name__ == "__main__":
    import sys

    def get_cli_arg(flag: str, default=None):
        """Return the value following *flag* in sys.argv, or *default* if the flag is absent.

        Exits with status 1 when the flag is the last argument (no value).
        """
        try:
            position = sys.argv.index(flag)
        except ValueError:
            return default
        value_position = position + 1
        if value_position >= len(sys.argv):
            print(f"Error: {flag} requires a value")
            sys.exit(1)
        return sys.argv[value_position]

    def safe_truncate(text: str, length: int = 100) -> str:
        """Collapse newlines and cut *text* to *length* characters; None becomes "(empty)"."""
        if text is None:
            return "(empty)"
        single_line = text.replace('\n', ' ')
        if len(single_line) > length:
            return single_line[:length] + "..."
        return single_line

    if len(sys.argv) < 2:
        usage_lines = (
            "Usage:",
            " python memory-retrieval.py search <query> [--guild <id>]",
            " python memory-retrieval.py recent [--guild <id>] [--limit <n>]",
            " python memory-retrieval.py stats",
            " python memory-retrieval.py add <content> [--type <type>] [--guild <id>]",
        )
        for usage_line in usage_lines:
            print(usage_line)
        sys.exit(1)

    cmd = sys.argv[1]
    # "search" and "add" need one positional argument after the command.
    has_positional = len(sys.argv) >= 3

    if cmd == "search" and has_positional:
        found = search_memories(sys.argv[2], guild_id=get_cli_arg("--guild"))
        print(f"Found {len(found)} memories:\n")
        for memory in found:
            print(f"[{memory['id']}] ({memory['memory_type']}) {safe_truncate(memory['content'])}")
            print(f" Score: {memory['fts_score']:.4f}, Confidence: {memory['confidence']}")
            print()

    elif cmd == "recent":
        guild_filter = get_cli_arg("--guild")
        raw_limit = get_cli_arg("--limit", "10")
        try:
            max_rows = int(raw_limit)
        except ValueError:
            print(f"Error: --limit must be a number, got '{raw_limit}'")
            sys.exit(1)

        recent = get_recent_memories(guild_id=guild_filter, limit=max_rows)
        print(f"Recent {len(recent)} memories:\n")
        for memory in recent:
            print(f"[{memory['id']}] ({memory['memory_type']}) {safe_truncate(memory['content'])}")
            print()

    elif cmd == "stats":
        print(json.dumps(get_memory_stats(), indent=2))

    elif cmd == "add" and has_positional:
        new_text = sys.argv[2]
        kind = get_cli_arg("--type", "fact")
        guild_filter = get_cli_arg("--guild")

        created_id = add_memory(new_text, memory_type=kind, guild_id=guild_filter)
        print(f"Added memory with ID: {created_id}")

    else:
        print("Unknown command or missing arguments")
        sys.exit(1)
|