275 lines
8.8 KiB
Python

import sqlite3
import json
import numpy as np
from typing import List, Dict, Any, Optional
class DatabaseManager:
    """SQLite-backed store for parsed document data.

    Persists hierarchical sections, their sentences, extracted variables,
    vector embeddings (as float32 blobs), pairwise similarity scores, and
    named statistics with optional JSON payloads.
    """

    def __init__(self, db_path: str = "constitution.db"):
        """Open (or create) the SQLite database at *db_path*.

        Rows are returned as :class:`sqlite3.Row` so callers can convert
        them directly to dicts.
        """
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.conn.row_factory = sqlite3.Row

    def create_tables(self):
        """Create all tables if they do not already exist. Idempotent."""
        cursor = self.conn.cursor()
        # Hierarchical document sections; parent_id is a self-reference.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS sections (
                id INTEGER PRIMARY KEY,
                section_type TEXT,
                parent_id INTEGER,
                title TEXT,
                content TEXT,
                line_start INTEGER,
                line_end INTEGER,
                hierarchy_level INTEGER,
                path TEXT,
                FOREIGN KEY (parent_id) REFERENCES sections(id)
            );
        """)
        # Extracted variables; names are unique across the corpus.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS variables (
                id INTEGER PRIMARY KEY,
                name TEXT UNIQUE,
                category TEXT,
                priority_level INTEGER,
                is_hard_constraint BOOLEAN,
                principal_assignment TEXT,
                frequency INTEGER DEFAULT 0,
                description TEXT
            );
        """)
        # Where each variable occurs (section/sentence plus surrounding text).
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS variable_occurrences (
                id INTEGER PRIMARY KEY,
                variable_id INTEGER,
                section_id INTEGER,
                sentence_id INTEGER,
                context TEXT,
                FOREIGN KEY (variable_id) REFERENCES variables(id),
                FOREIGN KEY (section_id) REFERENCES sections(id)
            );
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS sentences (
                id INTEGER PRIMARY KEY,
                section_id INTEGER,
                text TEXT,
                sentence_number INTEGER,
                line_number INTEGER,
                FOREIGN KEY (section_id) REFERENCES sections(id)
            );
        """)
        # Embedding vectors stored as raw float32 bytes (see add_embedding).
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                id INTEGER PRIMARY KEY,
                content_id INTEGER,
                content_type TEXT,
                embedding BLOB,
                embedding_dim INTEGER DEFAULT 768,
                chunk_start INTEGER,
                chunk_end INTEGER,
                FOREIGN KEY (content_id) REFERENCES sections(id) ON DELETE CASCADE
            );
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS similarity (
                id INTEGER PRIMARY KEY,
                content_id_1 INTEGER,
                content_id_2 INTEGER,
                similarity_score REAL,
                FOREIGN KEY (content_id_1) REFERENCES sections(id),
                FOREIGN KEY (content_id_2) REFERENCES sections(id)
            );
        """)
        # Named scalar metrics with an optional JSON payload.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS statistics (
                id INTEGER PRIMARY KEY,
                metric_name TEXT UNIQUE,
                metric_value REAL,
                json_data TEXT
            );
        """)
        self.conn.commit()

    def populate(
        self,
        sections: List[Dict],
        sentences: List[Dict],
        variables: List[Dict],
        constraints: List[Dict],
    ):
        """Bulk-insert parsed document data in one transaction.

        Row ids are assigned sequentially (starting at 1) in list order.
        Sentences must carry a ``section_id`` key referencing a section's
        assigned id. *constraints* is accepted for interface stability but
        is not currently persisted.
        """
        cursor = self.conn.cursor()
        section_id_map = {}
        for i, section in enumerate(sections, 1):
            # Resolve the declared parent through the map built so far;
            # top-level sections (no/falsy parent_id) get NULL.
            parent_id = (
                section_id_map.get(section.get("parent_id"))
                if section.get("parent_id")
                else None
            )
            cursor.execute(
                """
                INSERT INTO sections (id, section_type, parent_id, title, content, line_start, line_end, hierarchy_level, path)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    i,
                    section.get("section_type"),
                    parent_id,
                    section.get("title"),
                    section.get("content"),
                    section.get("line_start"),
                    section.get("line_end"),
                    section.get("hierarchy_level"),
                    section.get("path"),
                ),
            )
            section_id_map[i] = i
        for i, sentence in enumerate(sentences, 1):
            # BUG FIX: this previously read section_id from the stale
            # `section` variable left over from the loop above, so every
            # sentence got the same (wrong) section reference. Each
            # sentence must supply its own section_id.
            cursor.execute(
                """
                INSERT INTO sentences (id, section_id, text, sentence_number, line_number)
                VALUES (?, ?, ?, ?, ?)
                """,
                (
                    i,
                    sentence.get("section_id"),
                    sentence.get("text"),
                    sentence.get("sentence_number"),
                    sentence.get("line_number"),
                ),
            )
        var_id_map = {}
        for i, var in enumerate(variables, 1):
            cursor.execute(
                """
                INSERT INTO variables (id, name, category, priority_level, is_hard_constraint, principal_assignment, frequency, description)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    i,
                    var.get("name"),
                    var.get("category"),
                    var.get("priority_level"),
                    var.get("is_hard_constraint"),
                    var.get("principal_assignment"),
                    var.get("frequency", 0),
                    var.get("description"),
                ),
            )
            var_id_map[var.get("name")] = i
        self.conn.commit()

    def get_sections_with_embeddings(self) -> List[Dict]:
        """Return all sections, each with an ``embeddings`` list attached."""
        cursor = self.conn.cursor()
        cursor.execute("SELECT * FROM sections")
        rows = cursor.fetchall()  # fetch all first so the cursor can be reused
        sections = []
        for row in rows:
            section = dict(row)
            cursor.execute(
                "SELECT * FROM embeddings WHERE content_id = ?", (section["id"],)
            )
            embeddings = cursor.fetchall()
            section["embeddings"] = [dict(e) for e in embeddings]
            sections.append(section)
        return sections

    def get_variables(self) -> List[Dict]:
        """Return every row of the variables table as a dict."""
        cursor = self.conn.cursor()
        cursor.execute("SELECT * FROM variables")
        return [dict(row) for row in cursor.fetchall()]

    def get_sentences(self) -> List[Dict]:
        """Return every row of the sentences table as a dict."""
        cursor = self.conn.cursor()
        cursor.execute("SELECT * FROM sentences")
        return [dict(row) for row in cursor.fetchall()]

    def add_embedding(
        self,
        content_id: int,
        content_type: str,
        embedding: np.ndarray,
        chunk_start: Optional[int] = None,
        chunk_end: Optional[int] = None,
    ):
        """Store an embedding vector for *content_id*.

        The vector is coerced to float32 before serialization so that
        :meth:`get_embedding` (which decodes as float32) round-trips
        correctly regardless of the caller's dtype.
        """
        cursor = self.conn.cursor()
        arr = np.asarray(embedding, dtype=np.float32)
        cursor.execute(
            """
            INSERT INTO embeddings (content_id, content_type, embedding, embedding_dim, chunk_start, chunk_end)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                content_id,
                content_type,
                arr.tobytes(),
                len(arr),
                chunk_start,
                chunk_end,
            ),
        )
        self.conn.commit()

    def get_embedding(self, content_id: int) -> Optional[np.ndarray]:
        """Return the first stored embedding for *content_id*, or None."""
        cursor = self.conn.cursor()
        cursor.execute(
            "SELECT embedding FROM embeddings WHERE content_id = ? LIMIT 1",
            (content_id,),
        )
        row = cursor.fetchone()
        if row:
            # Stored as raw float32 bytes by add_embedding.
            return np.frombuffer(row["embedding"], dtype=np.float32)
        return None

    def add_similarity(self, content_id_1: int, content_id_2: int, score: float):
        """Record a similarity score between two content ids."""
        cursor = self.conn.cursor()
        cursor.execute(
            """
            INSERT INTO similarity (content_id_1, content_id_2, similarity_score)
            VALUES (?, ?, ?)
            """,
            (content_id_1, content_id_2, score),
        )
        self.conn.commit()

    def get_statistics(self, metric_name: str) -> Optional[Dict]:
        """Return the statistics row for *metric_name* (JSON decoded), or None."""
        cursor = self.conn.cursor()
        cursor.execute("SELECT * FROM statistics WHERE metric_name = ?", (metric_name,))
        row = cursor.fetchone()
        if row:
            data = dict(row)
            if data.get("json_data"):
                data["json_data"] = json.loads(data["json_data"])
            return data
        return None

    def set_statistics(
        self, metric_name: str, metric_value: float, json_data: Optional[Dict] = None
    ):
        """Insert or overwrite the statistic *metric_name*.

        *json_data*, when given, is serialized to JSON. An explicit empty
        dict is stored as "{}" rather than being dropped to NULL (the
        previous truthiness test conflated {} with None).
        """
        cursor = self.conn.cursor()
        json_str = json.dumps(json_data) if json_data is not None else None
        cursor.execute(
            """
            INSERT OR REPLACE INTO statistics (metric_name, metric_value, json_data)
            VALUES (?, ?, ?)
            """,
            (metric_name, metric_value, json_str),
        )
        self.conn.commit()

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()