petter2025 committed on
Commit
ff4d74f
·
verified ·
1 Parent(s): ed10db5

Update hf_demo.py

Browse files
Files changed (1) hide show
  1. hf_demo.py +35 -18
hf_demo.py CHANGED
@@ -364,12 +364,40 @@ class PolicyEngine:
364
  return True
365
  return False
366
 
 
 
 
367
  class RAGMemory:
368
- """Persistent RAG memory with SQLite and simple embeddings."""
369
  def __init__(self):
370
  self.db_path = f"{settings.data_dir}/memory.db"
371
  self._init_db()
372
  self.embedding_cache = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
  def _init_db(self):
375
  try:
@@ -420,25 +448,12 @@ class RAGMemory:
420
  if conn:
421
  conn.close()
422
 
423
- def _simple_embedding(self, text: str) -> List[float]:
424
- if text in self.embedding_cache:
425
- return self.embedding_cache[text]
426
- words = text.lower().split()
427
- trigrams = set()
428
- for word in words:
429
- for i in range(len(word) - 2):
430
- trigrams.add(word[i:i+3])
431
- vector = [hash(t) % 1000 / 1000.0 for t in sorted(trigrams)[:100]]
432
- while len(vector) < 100:
433
- vector.append(0.0)
434
- vector = vector[:100]
435
- self.embedding_cache[text] = vector
436
- return vector
437
-
438
  def store_incident(self, action: str, risk_score: float, risk_level: RiskLevel,
439
  confidence: float, allowed: bool, gates: List[Dict]):
440
  action_hash = hashlib.sha256(action.encode()).hexdigest()[:50]
441
- embedding = json.dumps(self._simple_embedding(action))
 
 
442
  try:
443
  with self._get_db() as conn:
444
  conn.execute('''
@@ -462,7 +477,9 @@ class RAGMemory:
462
  logger.error(f"Failed to store incident: {e}")
463
 
464
  def find_similar(self, action: str, limit: int = 5) -> List[Dict]:
465
- query_embedding = self._simple_embedding(action)
 
 
466
  try:
467
  with self._get_db() as conn:
468
  cursor = conn.execute('SELECT * FROM incidents ORDER BY timestamp DESC LIMIT 100')
 
364
  return True
365
  return False
366
 
367
+ # ==============================================================================
368
+ # UPGRADED RAG MEMORY WITH SENTENCE-TRANSFORMERS
369
+ # ==============================================================================
370
  class RAGMemory:
371
+ """Persistent RAG memory with SQLite and sentence‑transformer embeddings."""
372
  def __init__(self):
373
  self.db_path = f"{settings.data_dir}/memory.db"
374
  self._init_db()
375
  self.embedding_cache = {}
376
+ self._sentence_model = None # lazy loaded
377
+
378
+ def _get_sentence_model(self):
379
+ """Lazy load the sentence‑transformer model."""
380
+ if self._sentence_model is None:
381
+ from sentence_transformers import SentenceTransformer
382
+ # Using all-MiniLM-L6-v2 – fast and good for semantic similarity
383
+ self._sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
384
+ return self._sentence_model
385
+
386
+ def _build_incident_text(self, action: str) -> str:
387
+ """Create a descriptive text from the action."""
388
+ # You can enrich this with more context (risk level, component, etc.)
389
+ return f"Action: {action}"
390
+
391
+ def _simple_embedding(self, text: str) -> List[float]:
392
+ """Generate embedding using sentence‑transformer."""
393
+ if text in self.embedding_cache:
394
+ return self.embedding_cache[text]
395
+
396
+ model = self._get_sentence_model()
397
+ # encode returns a numpy array; convert to list for JSON storage
398
+ embedding = model.encode(text, convert_to_numpy=True).tolist()
399
+ self.embedding_cache[text] = embedding
400
+ return embedding
401
 
402
  def _init_db(self):
403
  try:
 
448
  if conn:
449
  conn.close()
450
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  def store_incident(self, action: str, risk_score: float, risk_level: RiskLevel,
452
  confidence: float, allowed: bool, gates: List[Dict]):
453
  action_hash = hashlib.sha256(action.encode()).hexdigest()[:50]
454
+ # Build a descriptive text and generate embedding
455
+ incident_text = self._build_incident_text(action)
456
+ embedding = json.dumps(self._simple_embedding(incident_text))
457
  try:
458
  with self._get_db() as conn:
459
  conn.execute('''
 
477
  logger.error(f"Failed to store incident: {e}")
478
 
479
  def find_similar(self, action: str, limit: int = 5) -> List[Dict]:
480
+ # Build query embedding from the action text
481
+ query_text = self._build_incident_text(action)
482
+ query_embedding = self._simple_embedding(query_text)
483
  try:
484
  with self._get_db() as conn:
485
  cursor = conn.execute('SELECT * FROM incidents ORDER BY timestamp DESC LIMIT 100')