|
|
""" |
|
|
Module for managing PostgreSQL database operations. |
|
|
- Replaces SQLite (sq_db.py) with PostgreSQL using SQLAlchemy |
|
|
- Provides functions to track user uploaded/generated files/data |
|
|
- Includes creating tables, adding files and embeddings, and managing users |
|
|
""" |
|
|
|
|
|
import bcrypt |
|
|
import os |
|
|
from typing import List, Optional, Dict, Any |
|
|
from datetime import datetime, timedelta |
|
|
from contextlib import contextmanager |
|
|
|
|
|
import pytz |
|
|
from sqlalchemy import create_engine, Column, String, Integer, DateTime, Boolean, Text, ForeignKey |
|
|
from sqlalchemy.ext.declarative import declarative_base |
|
|
from sqlalchemy.orm import sessionmaker, Session, relationship |
|
|
from logger import get_logger |
|
|
|
|
|
# Module-level logger for this database layer.
log = get_logger(name="pg_db")

# Connection string, read from the DATABASE_URL environment variable.
# NOTE(review): the fallback embeds development credentials
# (raguser/ragpass on localhost:5432/ragdb) — confirm this default is never
# relied on outside local development.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://raguser:ragpass@localhost:5432/ragdb")

# Shared engine for the whole module. pool_pre_ping=True tests each pooled
# connection for liveness on checkout so stale connections get replaced.
engine = create_engine(DATABASE_URL, pool_pre_ping=True, echo=False)

# Session factory: no autocommit and no autoflush, so pending changes are only
# written when a session is flushed/committed explicitly (get_db() commits).
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class shared by the ORM models in this module.
Base = declarative_base()

# Timezone used for the application-side timestamp defaults in this module.
IST = pytz.timezone('Asia/Kolkata')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class User(Base):
    """User account model (table ``users``).

    Accounts are soft-deleted by setting ``available`` to False rather than
    removing the row.
    """

    __tablename__ = "users"

    # Caller-supplied login identifier; also the primary key.
    user_id = Column(String, primary_key=True, index=True)

    # bcrypt hash of the password (decoded to str); the plaintext is never stored.
    password_hash = Column(String, nullable=False)

    # Set on insert from the IST clock. NOTE(review): the column is a naive
    # DateTime, so the stored wall time depends on how the driver/DB session
    # converts the timezone-aware value — confirm.
    created_at = Column(DateTime, default=lambda: datetime.now(IST))

    # Set when a password is verified successfully; NULL until the first login.
    last_login = Column(DateTime, nullable=True)

    # Soft-delete flag; most lookups in this module filter on available == True.
    available = Column(Boolean, default=True)

    # One-to-many links to the user's file and embedding records.
    files = relationship("UserFile", back_populates="user")

    embeddings = relationship("UserEmbedding", back_populates="user")
|
|
|
|
|
|
|
|
class UserFile(Base):
    """File record owned by a user (table ``user_files``).

    Records are soft-deleted by setting ``available`` to False.
    """

    __tablename__ = "user_files"

    # Surrogate autoincrement key; this is the "file id" used by callers.
    id = Column(Integer, primary_key=True, autoincrement=True)

    # Owning user; references users.user_id.
    user_id = Column(String, ForeignKey("users.user_id"), nullable=False, index=True)

    # File name as given by the caller. No unique constraint, so the same
    # name can appear on several records for one user.
    file_name = Column(String, nullable=False)

    # Storage path of the file (format is up to the caller).
    file_path = Column(String, nullable=False)

    # Optional file type label; not validated here.
    file_type = Column(String, nullable=True)

    # Set on insert from the IST clock (stored in a naive DateTime column).
    uploaded_at = Column(DateTime, default=lambda: datetime.now(IST))

    # Soft-delete flag.
    available = Column(Boolean, default=True)

    # Back-reference to the owning User.
    user = relationship("User", back_populates="files")
|
|
|
|
|
|
|
|
class UserEmbedding(Base):
    """Embedding/document record owned by a user (table ``user_embeddings``).

    There is no foreign key to ``user_files``: add_embedding_compat() ties an
    embedding to a file only by storing the file's name in ``source``.
    Records are soft-deleted by setting ``available`` to False.
    """

    __tablename__ = "user_embeddings"

    # Surrogate autoincrement key.
    id = Column(Integer, primary_key=True, autoincrement=True)

    # Owning user; references users.user_id.
    user_id = Column(String, ForeignKey("users.user_id"), nullable=False, index=True)

    # Presumably the point/document id in the Qdrant vector store — confirm.
    qdrant_doc_id = Column(String, nullable=True)

    # Free-form origin label; add_embedding_compat() stores the file name here.
    source = Column(String, nullable=True)

    # Set on insert from the IST clock (stored in a naive DateTime column).
    created_at = Column(DateTime, default=lambda: datetime.now(IST))

    # Soft-delete flag.
    available = Column(Boolean, default=True)

    # Back-reference to the owning User.
    user = relationship("User", back_populates="embeddings")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_database():
    """Create every table registered on ``Base`` that does not exist yet.

    Returns:
        True when table creation succeeded, False if it raised (the error
        is logged rather than propagated).
    """
    try:
        Base.metadata.create_all(bind=engine)
    except Exception as exc:
        log.error(f"Error creating database tables: {exc}")
        return False
    else:
        log.info("Database tables created successfully.")
        return True
|
|
|
|
|
|
|
|
@contextmanager
def get_db():
    """Provide a session scoped to a ``with`` block.

    The session is committed when the block exits normally. If the block, or
    the commit itself, raises, the session is rolled back and the exception
    is re-raised. The session is closed in every case.
    """
    session = SessionLocal()
    try:
        yield session
        # Commit lives inside the try so a failing commit is rolled back too.
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
|
|
|
|
|
|
|
|
def get_connection():
    """Return a fresh, unmanaged session (kept for existing callers).

    Unlike ``get_db()``, nothing is committed or closed automatically; the
    caller owns the returned session's lifecycle.
    """
    session = SessionLocal()
    return session
|
|
|
|
|
|
|
|
def delete_database() -> bool:
    """Drop every table registered on ``Base`` (all their data is lost).

    This is the PostgreSQL counterpart of deleting the SQLite database file.

    Returns:
        True when the tables were dropped, False if dropping raised (the
        error is logged rather than propagated).
    """
    try:
        Base.metadata.drop_all(bind=engine)
    except Exception as exc:
        log.error(f"Error dropping database tables: {exc}")
        return False
    else:
        log.info("All database tables dropped successfully.")
        return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_user(user_id: str, password: str) -> bool:
    """Create a new user whose password is stored as a bcrypt hash.

    Args:
        user_id: Identifier for the new user (the primary key).
        password: Plaintext password; only its bcrypt hash is persisted.

    Returns:
        True once the new row has been committed. False if the user id is
        already taken (including a concurrent insert racing this one) or any
        other error occurred; errors are logged.
    """
    # Local import keeps this block self-contained; sqlalchemy is already a
    # dependency of this module.
    from sqlalchemy.exc import IntegrityError

    try:
        with get_db() as db:
            # This lookup ignores `available`, so a soft-deleted user id cannot
            # be registered again.
            existing_user = db.query(User).filter(User.user_id == user_id).first()
            if existing_user:
                log.warning(f"User '{user_id}' already exists.")
                return False

            password_hash = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
            db.add(User(user_id=user_id, password_hash=password_hash))
        # The INSERT is only flushed and committed when the `with` block exits,
        # so success is reported after that point rather than before it.
    except IntegrityError:
        # Another request inserted the same user_id between our lookup and
        # our commit; the primary key rejected the duplicate.
        log.warning(f"User '{user_id}' already exists.")
        return False
    except Exception as e:
        log.error(f"Error creating user '{user_id}': {e}")
        return False

    log.info(f"User '{user_id}' created successfully.")
    return True
|
|
|
|
|
|
|
|
def verify_user(user_id: str, password: str) -> bool:
    """Check ``password`` against the stored bcrypt hash of an active user.

    On success, ``last_login`` is set to the current IST time; the change is
    committed when the session context exits.

    Returns:
        True for valid credentials. False if the user is missing or
        soft-deleted, the password does not match, or an error occurred
        (errors are logged).
    """
    try:
        with get_db() as db:
            account = db.query(User).filter(User.user_id == user_id, User.available == True).first()  # noqa: E712
            if not account:
                log.warning(f"User '{user_id}' not found.")
                return False

            password_ok = bcrypt.checkpw(
                password.encode('utf-8'),
                account.password_hash.encode('utf-8'),
            )
            if not password_ok:
                log.warning(f"Invalid password for user '{user_id}'.")
                return False

            account.last_login = datetime.now(IST)
            log.info(f"User '{user_id}' authenticated successfully.")
            return True
    except Exception as e:
        log.error(f"Error verifying user '{user_id}': {e}")
        return False
|
|
|
|
|
|
|
|
def get_all_users() -> List[Dict[str, Any]]:
    """Return every non-deleted user as a plain dict.

    Each dict carries ``user_id``, ``created_at`` and ``last_login``; the
    timestamps are ISO-8601 strings or None. Password hashes are not included.
    On error the problem is logged and an empty list is returned.
    """
    def _as_dict(account):
        # Serialize one User row; runs while the session is still open.
        created, last_seen = account.created_at, account.last_login
        return {
            "user_id": account.user_id,
            "created_at": created.isoformat() if created else None,
            "last_login": last_seen.isoformat() if last_seen else None,
        }

    try:
        with get_db() as db:
            active_accounts = db.query(User).filter(User.available == True).all()  # noqa: E712
            return list(map(_as_dict, active_accounts))
    except Exception as e:
        log.error(f"Error getting all users: {e}")
        return []
|
|
|
|
|
|
|
|
def delete_user(user_id: str) -> bool:
    """Soft-delete a user by setting ``available`` to False.

    Returns:
        True if a user row with ``user_id`` exists (even if it was already
        unavailable). False if no such row exists or an error occurred
        (errors are logged).
    """
    try:
        with get_db() as db:
            target = db.query(User).filter(User.user_id == user_id).first()
            if target is None:
                log.warning(f"User '{user_id}' not found.")
                return False
            target.available = False
            log.info(f"User '{user_id}' marked as deleted.")
            return True
    except Exception as e:
        log.error(f"Error deleting user '{user_id}': {e}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_file(user_id: str, file_name: str, file_path: str, file_type: Optional[str] = None) -> bool:
    """Insert a file record for a user.

    Args:
        user_id: Owner of the file; must reference an existing users row.
        file_name: Name of the file.
        file_path: Storage path of the file.
        file_type: Optional type label.

    Returns:
        True once the record is committed. False on any database error, for
        example a foreign-key violation when ``user_id`` does not exist; the
        error is logged.
    """
    try:
        with get_db() as db:
            db.add(UserFile(
                user_id=user_id,
                file_name=file_name,
                file_path=file_path,
                file_type=file_type,
            ))
        # get_db() flushes and commits on exit. Report success only after
        # that, so a failed commit is not logged as "added" first.
    except Exception as e:
        log.error(f"Error adding file '{file_name}' for user '{user_id}': {e}")
        return False

    log.info(f"File '{file_name}' added for user '{user_id}'.")
    return True
|
|
|
|
|
|
|
|
def get_user_files(user_id: str) -> List[Dict[str, Any]]:
    """Return a user's available (not soft-deleted) file records as dicts.

    Each dict has ``id``, ``file_name``, ``file_path``, ``file_type`` and
    ``uploaded_at`` (ISO-8601 string or None). On error the problem is
    logged and an empty list is returned.
    """
    try:
        with get_db() as db:
            query = db.query(UserFile).filter(
                UserFile.user_id == user_id,
                UserFile.available == True,  # noqa: E712
            )
            records = []
            for rec in query.all():
                uploaded = rec.uploaded_at
                records.append({
                    "id": rec.id,
                    "file_name": rec.file_name,
                    "file_path": rec.file_path,
                    "file_type": rec.file_type,
                    "uploaded_at": uploaded.isoformat() if uploaded else None,
                })
            return records
    except Exception as e:
        log.error(f"Error getting files for user '{user_id}': {e}")
        return []
|
|
|
|
|
|
|
|
def delete_file(file_id: int) -> bool:
    """Soft-delete the file record with id ``file_id``, whoever owns it.

    Returns:
        True if the record exists (it is marked unavailable). False if it
        does not exist or an error occurred (errors are logged).
    """
    try:
        with get_db() as db:
            target = db.query(UserFile).filter(UserFile.id == file_id).first()
            if target is None:
                log.warning(f"File ID {file_id} not found.")
                return False
            target.available = False
            log.info(f"File ID {file_id} marked as deleted.")
            return True
    except Exception as e:
        log.error(f"Error deleting file ID {file_id}: {e}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_embedding(user_id: str, qdrant_doc_id: Optional[str] = None, source: Optional[str] = None) -> bool:
    """Insert an embedding record for a user.

    Args:
        user_id: Owner of the embedding; must reference an existing users row.
        qdrant_doc_id: Optional id of the corresponding vector-store document.
        source: Optional origin label (add_embedding_compat() passes the
            file name).

    Returns:
        True once the record is committed. False on any database error,
        which is logged.
    """
    try:
        with get_db() as db:
            db.add(UserEmbedding(
                user_id=user_id,
                qdrant_doc_id=qdrant_doc_id,
                source=source,
            ))
        # get_db() flushes and commits on exit. Report success only after
        # that, so a failed commit is not logged as "added" first.
    except Exception as e:
        log.error(f"Error adding embedding for user '{user_id}': {e}")
        return False

    log.info(f"Embedding added for user '{user_id}'.")
    return True
|
|
|
|
|
|
|
|
def get_user_embeddings(user_id: str) -> List[Dict[str, Any]]:
    """Return a user's available (not soft-deleted) embedding records as dicts.

    Each dict has ``id``, ``qdrant_doc_id``, ``source`` and ``created_at``
    (ISO-8601 string or None). On error the problem is logged and an empty
    list is returned.
    """
    try:
        with get_db() as db:
            query = db.query(UserEmbedding).filter(
                UserEmbedding.user_id == user_id,
                UserEmbedding.available == True,  # noqa: E712
            )
            rows = []
            for emb in query.all():
                created = emb.created_at
                rows.append({
                    "id": emb.id,
                    "qdrant_doc_id": emb.qdrant_doc_id,
                    "source": emb.source,
                    "created_at": created.isoformat() if created else None,
                })
            return rows
    except Exception as e:
        log.error(f"Error getting embeddings for user '{user_id}': {e}")
        return []
|
|
|
|
|
|
|
|
def delete_embedding(embedding_id: int) -> bool:
    """Soft-delete the embedding record with id ``embedding_id``.

    Returns:
        True if the record exists (it is marked unavailable). False if it
        does not exist or an error occurred (errors are logged).
    """
    try:
        with get_db() as db:
            target = db.query(UserEmbedding).filter(UserEmbedding.id == embedding_id).first()
            if target is None:
                log.warning(f"Embedding ID {embedding_id} not found.")
                return False
            target.available = False
            log.info(f"Embedding ID {embedding_id} marked as deleted.")
            return True
    except Exception as e:
        log.error(f"Error deleting embedding ID {embedding_id}: {e}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_tables() -> bool:
    """Create all tables; alias for ``init_database()`` matching the sq_db.py API.

    Returns:
        The success flag from ``init_database()``. It was previously
        discarded, so callers could not tell whether creation failed.
        Callers that ignore the return value are unaffected.
    """
    return init_database()
|
|
|
|
|
|
|
|
def add_user(user_id: str, name: str, password: str) -> bool:
    """Create a user through the sq_db.py-compatible three-argument signature.

    ``name`` is accepted only so existing callers keep working. The users
    table has no column for it, so it is not stored.

    Returns:
        Whatever ``create_user()`` returns for ``user_id`` and ``password``.
    """
    # `name` intentionally unused — see docstring.
    return create_user(user_id, password)
|
|
|
|
|
|
|
|
def check_user_exists(user_id: str) -> bool:
    """Return True if an available (not soft-deleted) user ``user_id`` exists.

    Returns False when the user is missing or soft-deleted, and also when the
    lookup fails (the error is logged).
    """
    try:
        with get_db() as db:
            match = (
                db.query(User)
                .filter(User.user_id == user_id, User.available == True)  # noqa: E712
                .first()
            )
            found = match is not None
        return found
    except Exception as e:
        log.error(f"Error checking if user '{user_id}' exists: {e}")
        return False
|
|
|
|
|
|
|
|
def authenticate_user(user_id: str, password: str) -> tuple[bool, str]:
    """Authenticate a user and report the outcome as ``(success, message)``.

    ``verify_user()`` already checks the bcrypt hash, stamps ``last_login``
    with the IST-aware clock and logs the result. This function therefore
    only translates its boolean into a message.

    Previously a second session re-wrote ``last_login`` with a naive
    ``datetime.now()``, which was inconsistent with every other timestamp.
    It also logged success twice and could turn a successful login into an
    error.

    Returns:
        (True, "Authentication successful") for valid credentials,
        (False, "Invalid credentials") otherwise, or (False, "Error: ...")
        if an unexpected exception escapes.
    """
    try:
        if verify_user(user_id=user_id, password=password):
            return (True, "Authentication successful")
        return (False, "Invalid credentials")
    except Exception as e:
        log.error(f"Error authenticating user '{user_id}': {e}")
        return (False, f"Error: {str(e)}")
|
|
|
|
|
|
|
|
def get_file_id_by_name(user_id: str, file_name: str) -> int:
    """Return the id of an available file named ``file_name`` owned by ``user_id``.

    Returns -1 if there is no such file or the lookup fails (logged).

    NOTE(review): the query has no ORDER BY. If several available records
    share the same name, which one's id is returned is not defined.
    """
    try:
        with get_db() as db:
            record = (
                db.query(UserFile)
                .filter(
                    UserFile.user_id == user_id,
                    UserFile.file_name == file_name,
                    UserFile.available == True,  # noqa: E712
                )
                .first()
            )
            if record is None:
                return -1
            return record.id
    except Exception as e:
        log.error(f"Error getting file ID for '{file_name}': {e}")
        return -1
|
|
|
|
|
|
|
|
def mark_file_removed(user_id: str, file_id: int) -> bool:
    """Soft-delete file ``file_id``, but only if it belongs to ``user_id``.

    Returns:
        True if a matching record exists (it is marked unavailable). False
        if no record with that id belongs to the user, or an error occurred
        (errors are logged).
    """
    try:
        with get_db() as db:
            owned = (
                db.query(UserFile)
                .filter(UserFile.id == file_id, UserFile.user_id == user_id)
                .first()
            )
            if owned is None:
                log.warning(f"File ID {file_id} not found for user '{user_id}'.")
                return False
            owned.available = False
            log.info(f"File ID {file_id} marked as removed.")
            return True
    except Exception as e:
        log.error(f"Error marking file ID {file_id} as removed: {e}")
        return False
|
|
|
|
|
|
|
|
def mark_embeddings_removed(vector_ids: List[str]) -> bool:
    """Soft-delete every embedding whose ``qdrant_doc_id`` is in ``vector_ids``.

    The match is not restricted to any user, and rows that are already
    unavailable are matched (and counted) as well.

    Returns:
        True when the update ran, even if nothing matched. False if an error
        occurred (logged).
    """
    try:
        with get_db() as db:
            matched = (
                db.query(UserEmbedding)
                .filter(UserEmbedding.qdrant_doc_id.in_(vector_ids))
                .all()
            )
            for row in matched:
                row.available = False
            log.info(f"Marked {len(matched)} embeddings as removed.")
            return True
    except Exception as e:
        log.error(f"Error marking embeddings as removed: {e}")
        return False
|
|
|
|
|
|
|
|
def get_old_files(user_id: str, time: int = 12*3600) -> dict:
    """Find a user's files older than ``time`` seconds and their embeddings.

    Nothing is modified; this only collects candidates for cleanup.

    Args:
        user_id: Owner of the files.
        time: Age threshold in seconds (default 12 hours). Files uploaded
            strictly before ``now - time`` are returned.

    Returns:
        ``{'files': [...], 'embeddings': [...]}``. ``files`` holds the names
        of the old, still-available files. ``embeddings`` holds the non-empty
        ``qdrant_doc_id`` values of the user's still-available embeddings
        whose ``source`` is one of those names. Both lists are empty on error
        (logged).
    """
    try:
        with get_db() as db:
            # Timezone-aware, matching how uploaded_at is written
            # (datetime.now(IST)). The previous naive datetime.now() was only
            # correct when the host's local zone equalled the DB session zone.
            cutoff_time = datetime.now(IST) - timedelta(seconds=time)

            old_files = db.query(UserFile).filter(
                UserFile.user_id == user_id,
                UserFile.uploaded_at < cutoff_time,
                UserFile.available == True,  # noqa: E712
            ).all()
            filenames = [f.file_name for f in old_files]

            embedding_ids: List[str] = []
            if filenames:
                # user_embeddings has no foreign key to user_files, so the old
                # `.join(UserFile)` could not infer an ON clause and raised.
                # That made this function return empty lists whenever old
                # files existed. Embeddings are linked to files by
                # (user_id, source == file_name), which is what
                # add_embedding_compat() stores.
                # NOTE(review): a newer file with the same name would also
                # match here.
                embeddings = db.query(UserEmbedding).filter(
                    UserEmbedding.user_id == user_id,
                    UserEmbedding.source.in_(filenames),
                    UserEmbedding.available == True,  # noqa: E712
                ).all()
                embedding_ids = [emb.qdrant_doc_id for emb in embeddings if emb.qdrant_doc_id]

            return {
                'files': filenames,
                'embeddings': embedding_ids
            }
    except Exception as e:
        log.error(f"Error getting old files for user '{user_id}': {e}")
        return {'files': [], 'embeddings': []}
|
|
|
|
|
|
|
|
|
|
|
def add_file_compat(user_id: str, filename: str) -> int:
    """Add a file record given only its name (sq_db.py compatibility).

    The stored path is ``/fastAPI/user_uploads/{user_id}/{filename}``.

    Returns:
        The id of the record inserted by this call, or -1 on failure (logged).
    """
    file_path = f"/fastAPI/user_uploads/{user_id}/{filename}"
    try:
        with get_db() as db:
            record = UserFile(user_id=user_id, file_name=filename, file_path=file_path)
            db.add(record)
            # Flush to get the autoincrement id of *this* row. The previous
            # approach looked the id up by name afterwards with an unordered
            # .first(), which could return an older record when the same
            # filename had been uploaded before.
            db.flush()
            new_id = record.id
        # get_db() has committed by the time we get here.
    except Exception as e:
        log.error(f"Error adding file '{filename}' for user '{user_id}': {e}")
        return -1

    log.info(f"File '{filename}' added for user '{user_id}'.")
    return new_id
|
|
|
|
|
|
|
|
|
|
|
def add_embedding_compat(file_id: int, vector_id: str) -> bool:
    """Record an embedding for the file with id ``file_id`` (sq_db.py compatibility).

    The embedding is stored for the file's owner with ``source`` set to the
    file's name, so it can later be matched back to that file by name.

    Returns:
        True if the embedding was recorded. False if no file has that id, or
        a database error occurred (logged).
    """
    try:
        with get_db() as db:
            file_record = db.query(UserFile).filter(UserFile.id == file_id).first()
            if file_record is None:
                log.warning(f"File ID {file_id} not found.")
                return False
            # Copy the needed values out while the session is open.
            owner_id = file_record.user_id
            source_name = file_record.file_name
        # add_embedding() opens its own session. Calling it after this one has
        # closed avoids holding two pooled connections at the same time.
        return add_embedding(user_id=owner_id, qdrant_doc_id=vector_id, source=source_name)
    except Exception as e:
        log.error(f"Error adding embedding for file ID {file_id}: {e}")
        return False
|
|
|
|
|
|
|
|
|
|
|
def get_user_files_compat(user_id: str) -> List[str]:
    """Return only the names of a user's available files (sq_db.py compatibility).

    Yields an empty list when ``get_user_files()`` fails, since that function
    logs the error and returns an empty list.
    """
    names: List[str] = []
    for record in get_user_files(user_id=user_id):
        names.append(record['file_name'])
    return names
|
|
|
|
|
|
|
|
|
|
|
# Import-time side effect: create any missing tables as soon as this module is
# imported. init_database() logs and returns False on failure instead of
# raising, so a database error here does not abort the import.
init_database()
|
|
|