Update hf_demo.py
Browse files- hf_demo.py +35 -18
hf_demo.py
CHANGED
|
@@ -364,12 +364,40 @@ class PolicyEngine:
|
|
| 364 |
return True
|
| 365 |
return False
|
| 366 |
|
|
|
|
|
|
|
|
|
|
| 367 |
class RAGMemory:
|
| 368 |
-
"""Persistent RAG memory with SQLite and
|
| 369 |
def __init__(self):
    """Open (or create) the SQLite-backed memory and start with an empty cache."""
    # Database file lives under the configured data directory.
    self.db_path = "{}/memory.db".format(settings.data_dir)
    self._init_db()
    # Per-process memoization of text -> embedding vector.
    self.embedding_cache = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
def _init_db(self):
|
| 375 |
try:
|
|
@@ -420,25 +448,12 @@ class RAGMemory:
|
|
| 420 |
if conn:
|
| 421 |
conn.close()
|
| 422 |
|
| 423 |
-
def _simple_embedding(self, text: str) -> List[float]:
|
| 424 |
-
if text in self.embedding_cache:
|
| 425 |
-
return self.embedding_cache[text]
|
| 426 |
-
words = text.lower().split()
|
| 427 |
-
trigrams = set()
|
| 428 |
-
for word in words:
|
| 429 |
-
for i in range(len(word) - 2):
|
| 430 |
-
trigrams.add(word[i:i+3])
|
| 431 |
-
vector = [hash(t) % 1000 / 1000.0 for t in sorted(trigrams)[:100]]
|
| 432 |
-
while len(vector) < 100:
|
| 433 |
-
vector.append(0.0)
|
| 434 |
-
vector = vector[:100]
|
| 435 |
-
self.embedding_cache[text] = vector
|
| 436 |
-
return vector
|
| 437 |
-
|
| 438 |
def store_incident(self, action: str, risk_score: float, risk_level: RiskLevel,
|
| 439 |
confidence: float, allowed: bool, gates: List[Dict]):
|
| 440 |
action_hash = hashlib.sha256(action.encode()).hexdigest()[:50]
|
| 441 |
-
|
|
|
|
|
|
|
| 442 |
try:
|
| 443 |
with self._get_db() as conn:
|
| 444 |
conn.execute('''
|
|
@@ -462,7 +477,9 @@ class RAGMemory:
|
|
| 462 |
logger.error(f"Failed to store incident: {e}")
|
| 463 |
|
| 464 |
def find_similar(self, action: str, limit: int = 5) -> List[Dict]:
|
| 465 |
-
|
|
|
|
|
|
|
| 466 |
try:
|
| 467 |
with self._get_db() as conn:
|
| 468 |
cursor = conn.execute('SELECT * FROM incidents ORDER BY timestamp DESC LIMIT 100')
|
|
|
|
| 364 |
return True
|
| 365 |
return False
|
| 366 |
|
| 367 |
+
# ==============================================================================
|
| 368 |
+
# UPGRADED RAG MEMORY WITH SENTENCE-TRANSFORMERS
|
| 369 |
+
# ==============================================================================
|
| 370 |
class RAGMemory:
|
| 371 |
+
"""Persistent RAG memory with SQLite and sentence‑transformer embeddings."""
|
| 372 |
def __init__(self):
    """Create/open the SQLite store and initialize empty caches."""
    self.db_path = f"{settings.data_dir}/memory.db"
    self._init_db()
    # The sentence-transformer is loaded lazily on first embedding request.
    self._sentence_model = None
    # Memoizes text -> embedding so repeated actions are encoded once.
    self.embedding_cache = {}
|
| 377 |
+
|
| 378 |
+
def _get_sentence_model(self):
    """Return the shared SentenceTransformer, loading it on first use."""
    if self._sentence_model is not None:
        return self._sentence_model
    # Deferred import keeps startup fast when embeddings are never requested.
    from sentence_transformers import SentenceTransformer
    # all-MiniLM-L6-v2: small and fast, with solid semantic-similarity quality.
    self._sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
    return self._sentence_model
|
| 385 |
+
|
| 386 |
+
def _build_incident_text(self, action: str) -> str:
|
| 387 |
+
"""Create a descriptive text from the action."""
|
| 388 |
+
# You can enrich this with more context (risk level, component, etc.)
|
| 389 |
+
return f"Action: {action}"
|
| 390 |
+
|
| 391 |
+
def _simple_embedding(self, text: str) -> List[float]:
    """Embed *text* with the sentence-transformer, memoized per exact string."""
    try:
        # Cache hit: reuse the previously computed vector.
        return self.embedding_cache[text]
    except KeyError:
        pass
    model = self._get_sentence_model()
    # encode() yields a numpy array; a plain list is JSON-serializable.
    vector = model.encode(text, convert_to_numpy=True).tolist()
    self.embedding_cache[text] = vector
    return vector
|
| 401 |
|
| 402 |
def _init_db(self):
|
| 403 |
try:
|
|
|
|
| 448 |
if conn:
|
| 449 |
conn.close()
|
| 450 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
def store_incident(self, action: str, risk_score: float, risk_level: RiskLevel,
|
| 452 |
confidence: float, allowed: bool, gates: List[Dict]):
|
| 453 |
action_hash = hashlib.sha256(action.encode()).hexdigest()[:50]
|
| 454 |
+
# Build a descriptive text and generate embedding
|
| 455 |
+
incident_text = self._build_incident_text(action)
|
| 456 |
+
embedding = json.dumps(self._simple_embedding(incident_text))
|
| 457 |
try:
|
| 458 |
with self._get_db() as conn:
|
| 459 |
conn.execute('''
|
|
|
|
| 477 |
logger.error(f"Failed to store incident: {e}")
|
| 478 |
|
| 479 |
def find_similar(self, action: str, limit: int = 5) -> List[Dict]:
|
| 480 |
+
# Build query embedding from the action text
|
| 481 |
+
query_text = self._build_incident_text(action)
|
| 482 |
+
query_embedding = self._simple_embedding(query_text)
|
| 483 |
try:
|
| 484 |
with self._get_db() as conn:
|
| 485 |
cursor = conn.execute('SELECT * FROM incidents ORDER BY timestamp DESC LIMIT 100')
|