import os
import tempfile
import logging
import numpy as np
import torch
import torch.nn as nn
import whisper
import librosa
import asyncio
import random
import requests
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, UploadFile, File, HTTPException, Request, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from speechbrain.inference.classifiers import EncoderClassifier
import torchaudio
import json
from pydantic import BaseModel
from supabase import create_client, Client

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

class ModelManager:
    """Centralized model management for all ML models."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._initialized = True

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        self.emotion_model = None
        self.whisper_model = None
        self.text_tokenizer = None
        self.text_model = None
        self.speechbrain_model = None

        # Model paths
        self.MODEL_PATHS = {
            'whisper_model': 'base',
            'text_model': 'emotion-distilbert-model',
            'speechbrain_model': 'speechbrain/emotion-recognition-wav2vec2-IEMOCAP'
        }

        # Constants
        self.EMOTIONS = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]
        self.SAMPLE_RATE = 16000
        self.TEXT_EMOTIONS = ["sadness", "joy", "love", "anger", "fear", "surprise"]

        # SpeechBrain emotion mapping
        self.SPEECHBRAIN_EMOTION_MAP = {
            'neu': 'Neutral',
            'hap': 'Happy',
            'sad': 'Sad',
            'ang': 'Angry',
            'fea': 'Fear',
            'dis': 'Disgust',
            'sur': 'Surprise'
        }

    def load_all_models(self):
        """Load all required models."""
        try:
            logger.info("Starting to load all models...")
            self._load_emotion_model()
            self._load_whisper_model()
            self._load_text_models()
            self._load_speechbrain_model()
            logger.info("All models loaded successfully!")
            return True
        except Exception as e:
            logger.error(f"Error loading models: {str(e)}")
            raise

    def _load_emotion_model(self):
        """Use DeepFace for emotion recognition."""
        try:
            logger.info("Loading DeepFace for emotion recognition...")
            from deepface import DeepFace
            self.emotion_model = DeepFace
            logger.info("DeepFace loaded successfully")
        except Exception as e:
            logger.error(f"Failed to initialize DeepFace: {str(e)}")
            raise

    def _load_whisper_model(self):
        """Load the Whisper speech-to-text model."""
        try:
            logger.info("Loading Whisper model...")
            self.whisper_model = whisper.load_model(self.MODEL_PATHS['whisper_model'])
            logger.info("Whisper model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {str(e)}")
            raise

    def _load_text_models(self):
        """Load the text emotion classification model and tokenizer."""
        try:
            logger.info("Loading text emotion model...")
            model_path = self.MODEL_PATHS['text_model']
            # Try to load from local path first, then from HuggingFace Hub
            if os.path.exists(model_path):
                self.text_tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
                self.text_model = DistilBertForSequenceClassification.from_pretrained(model_path)
            else:
                # Use a public emotion model from HuggingFace
                logger.info("Local model not found, using HuggingFace model...")
                self.text_tokenizer = DistilBertTokenizerFast.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
                self.text_model = DistilBertForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
            self.text_model.eval()
            logger.info("Text models loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load text models: {str(e)}")
            raise

    def _load_speechbrain_model(self):
        """Load SpeechBrain emotion recognition model."""
        try:
            logger.info("Loading SpeechBrain emotion recognition model...")
            self.speechbrain_model = EncoderClassifier.from_hparams(
                source=self.MODEL_PATHS['speechbrain_model'],
                savedir="pretrained_models/emotion-recognition-wav2vec2-IEMOCAP",
                run_opts={"device": "cpu"}
            )
            logger.info("SpeechBrain emotion recognition model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load SpeechBrain model: {str(e)}")
            raise

    def get_emotion_model(self):
        if self.emotion_model is None:
            self._load_emotion_model()
        return self.emotion_model

    def get_whisper_model(self):
        if self.whisper_model is None:
            self._load_whisper_model()
        return self.whisper_model

    def get_text_models(self):
        if self.text_model is None or self.text_tokenizer is None:
            self._load_text_models()
        return self.text_tokenizer, self.text_model

    def get_speechbrain_model(self):
        if self.speechbrain_model is None:
            self._load_speechbrain_model()
        return self.speechbrain_model
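
# Note: ModelManager is a singleton (enforced in __new__), so ModelManager() always
# returns the same instance; the get_* accessors lazily load a model on first use
# if it was not loaded during startup.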

# Initialize FastAPI app
app = FastAPI(title="Manan ML API - Emotion Recognition")

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"]
)

# Initialize model manager
model_manager = ModelManager()

# Image transformation pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

@app.on_event("startup")
async def startup_event():
    """Initialize all models when the application starts."""
    try:
        logger.info("Starting model initialization...")
        model_manager.load_all_models()
        logger.info("All models initialized successfully!")
    except Exception as e:
        logger.error(f"Failed to initialize models: {str(e)}")
        # Don't raise - let the app start and load models on demand


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "status": "running",
        "message": "Manan ML API is running!",
        "endpoints": [
            "/pred_face - Face emotion prediction",
            "/predict_audio_batch - Voice emotion prediction",
            "/predict_text/ - Text emotion prediction"
        ]
    }


async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "device": str(model_manager.device)}

# Helper function for SpeechBrain prediction
def predict_emotion_speechbrain(audio_path: str) -> Dict[str, Any]:
    """Predict emotion from audio using SpeechBrain."""
    try:
        speechbrain_model = model_manager.get_speechbrain_model()
        signal, sr = torchaudio.load(audio_path)
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            signal = resampler(signal)
        if signal.dim() == 1:
            signal = signal.unsqueeze(0)
        elif signal.dim() == 3:
            signal = signal.squeeze(1)
        device = next(speechbrain_model.mods.wav2vec2.parameters()).device
        signal = signal.to(device)
        with torch.no_grad():
            feats = speechbrain_model.mods.wav2vec2(signal)
            pooled = speechbrain_model.mods.avg_pool(feats)
            out = speechbrain_model.mods.output_mlp(pooled)
            out_prob = speechbrain_model.hparams.softmax(out)
        score, index = torch.max(out_prob, dim=-1)
        predicted_emotion = speechbrain_model.hparams.label_encoder.decode_ndim(index.cpu())
        if isinstance(predicted_emotion, list):
            if isinstance(predicted_emotion[0], list):
                emotion_key = str(predicted_emotion[0][0]).lower()[:3]
            else:
                emotion_key = str(predicted_emotion[0]).lower()[:3]
        else:
            emotion_key = str(predicted_emotion).lower()[:3]
        emotion = model_manager.SPEECHBRAIN_EMOTION_MAP.get(emotion_key, 'Neutral')
        probs = out_prob[0].detach().cpu().numpy()
        if probs.ndim > 1:
            probs = probs.flatten()
        all_emotions = speechbrain_model.hparams.label_encoder.decode_ndim(
            torch.arange(len(probs))
        )
        prob_dict = {}
        for i in range(len(probs)):
            if i < len(all_emotions):
                if isinstance(all_emotions[i], list):
                    key = str(all_emotions[i][0]).lower()[:3]
                else:
                    key = str(all_emotions[i]).lower()[:3]
                emotion_name = model_manager.SPEECHBRAIN_EMOTION_MAP.get(key, f'emotion_{i}')
                prob_dict[emotion_name] = float(probs[i])
        confidence = float(score[0])
        return {
            'emotion': emotion,
            'confidence': confidence,
            'probabilities': prob_dict
        }
    except Exception as e:
        logger.error(f"Error predicting emotion with SpeechBrain: {str(e)}")
        raise
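
# The helper returns a dict shaped like the following (values are illustrative only):
#   {"emotion": "Happy", "confidence": 0.87,
#    "probabilities": {"Neutral": 0.05, "Happy": 0.87, "Sad": 0.03, "Angry": 0.05}}
# where the keys come from SPEECHBRAIN_EMOTION_MAP applied to the model's IEMOCAP labels.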

def transcribe_audio(audio_path: str) -> str:
    """Transcribe audio to text using Whisper."""
    try:
        # Use the lazy accessor so the model is loaded on demand if startup loading failed
        result = model_manager.get_whisper_model().transcribe(audio_path)
        return result["text"].strip()
    except Exception as e:
        logger.error(f"Error in audio transcription: {str(e)}")
        return ""

# ============== API ENDPOINTS ==============

@app.post("/pred_face")
async def predict_face_emotion(
    files: List[UploadFile] = File(...),
    questions: str = Form(None)
):
    """Predict emotions from face images using DeepFace."""
    from deepface import DeepFace

    logger.info(f"Received {len(files)} files for face prediction")
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    temp_files = []
    try:
        questions_data = {}
        question_count = 0
        if questions:
            try:
                questions_data = json.loads(questions)
                question_count = len(questions_data)
            except json.JSONDecodeError:
                raise HTTPException(status_code=400, detail="Invalid questions JSON format.")
        else:
            question_count = 3
            questions_data = {str(i): {"text": f"Question {i+1}", "imageCount": 1} for i in range(question_count)}

        # Group uploaded images by question index encoded in the filename ("q{idx}_...")
        question_files = {str(i): [] for i in range(question_count)}
        for file in files:
            if '_' in file.filename and file.filename.startswith('q'):
                try:
                    q_idx = file.filename.split('_')[0][1:]
                    if q_idx in question_files:
                        question_files[q_idx].append(file)
                except Exception as e:
                    logger.warning(f"Skipping file {file.filename}: {e}")

        results = []
        for q_idx, q_files in question_files.items():
            if not q_files:
                results.append({
                    "emotion": "Unknown",
                    "probabilities": {e: 0.0 for e in model_manager.EMOTIONS}
                })
                continue

            probs_list = []
            for file in q_files:
                try:
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
                        content = await file.read()
                        tmp.write(content)
                        temp_path = tmp.name
                    temp_files.append(temp_path)

                    analysis = DeepFace.analyze(
                        img_path=temp_path,
                        actions=['emotion'],
                        enforce_detection=False,
                        silent=True
                    )
                    if isinstance(analysis, list):
                        analysis = analysis[0]

                    emotion_scores = analysis.get('emotion', {})
                    dominant_emotion = analysis.get('dominant_emotion', 'neutral')

                    # DeepFace returns percentages; normalize to [0, 1]
                    normalized_probs = {}
                    for emo in model_manager.EMOTIONS:
                        key = emo.lower()
                        normalized_probs[emo] = emotion_scores.get(key, 0.0) / 100.0
                    probs_list.append(normalized_probs)
                except Exception as e:
                    logger.error(f"Error processing {file.filename}: {e}")

            if probs_list:
                avg_probs = {}
                for emo in model_manager.EMOTIONS:
                    avg_probs[emo] = sum(p.get(emo, 0) for p in probs_list) / len(probs_list)
                dominant_emotion = max(avg_probs, key=avg_probs.get)
                results.append({
                    "emotion": dominant_emotion,
                    "probabilities": avg_probs
                })
            else:
                results.append({
                    "emotion": "Unknown",
                    "probabilities": {e: 0.0 for e in model_manager.EMOTIONS}
                })

        return results
    except Exception as e:
        logger.error(f"Error in face emotion prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        for file_path in temp_files:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {file_path}: {e}")

@app.post("/predict_audio_batch")
async def predict_audio_batch(files: List[UploadFile] = File(...)):
    """Predict emotions from multiple audio files using SpeechBrain."""
    logger.info(f"Received {len(files)} audio files for prediction")
    if not files:
        raise HTTPException(status_code=400, detail="No audio files provided")

    temp_files = []
    results = []
    try:
        for file in files:
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                    content = await file.read()
                    tmp.write(content)
                    temp_path = tmp.name
                temp_files.append(temp_path)

                prediction = predict_emotion_speechbrain(temp_path)
                results.append(prediction)
                logger.info(f"Predicted emotion for {file.filename}: {prediction['emotion']}")
            except Exception as e:
                logger.error(f"Error processing {file.filename}: {e}")
                results.append({
                    'emotion': 'Unknown',
                    'confidence': 0.0,
                    'probabilities': {},
                    'error': str(e)
                })
        return {'status': 'success', 'results': results}
    except Exception as e:
        logger.error(f"Error in audio batch prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        for file_path in temp_files:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {file_path}: {e}")

@app.post("/predict_text/")
async def predict_text_emotion(files: List[UploadFile] = File(...)):
    """Transcribe audio and predict text emotion."""
    logger.info(f"Received {len(files)} audio files for text prediction")
    if not files:
        raise HTTPException(status_code=400, detail="No audio files provided")

    temp_files = []
    results = []
    try:
        tokenizer, text_model = model_manager.get_text_models()
        whisper_model = model_manager.get_whisper_model()

        for file in files:
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                    content = await file.read()
                    tmp.write(content)
                    temp_path = tmp.name
                temp_files.append(temp_path)

                # Transcribe
                transcription = whisper_model.transcribe(temp_path)
                transcript = transcription["text"].strip()
                logger.info(f"Transcribed: {transcript}")

                if not transcript:
                    results.append({
                        'transcript': '',
                        'emotion': 'neutral',
                        'confidence': 0.0,
                        'probabilities': {}
                    })
                    continue

                # Predict emotion from text
                inputs = tokenizer(
                    transcript,
                    return_tensors="pt",
                    truncation=True,
                    max_length=128,
                    padding=True
                )
                with torch.no_grad():
                    outputs = text_model(**inputs)
                    probs = torch.softmax(outputs.logits, dim=1)[0]

                # Get emotion labels
                emotion_labels = model_manager.TEXT_EMOTIONS
                if hasattr(text_model.config, 'id2label'):
                    emotion_labels = [text_model.config.id2label[i] for i in range(len(probs))]

                prob_dict = {emotion_labels[i]: float(probs[i]) for i in range(len(probs))}
                predicted_idx = torch.argmax(probs).item()
                predicted_emotion = emotion_labels[predicted_idx]
                confidence = float(probs[predicted_idx])

                results.append({
                    'transcript': transcript,
                    'emotion': predicted_emotion,
                    'confidence': confidence,
                    'probabilities': prob_dict
                })
            except Exception as e:
                logger.error(f"Error processing {file.filename}: {e}")
                results.append({
                    'transcript': '',
                    'emotion': 'unknown',
                    'confidence': 0.0,
                    'error': str(e)
                })
        return results
    except Exception as e:
        logger.error(f"Error in text prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        for file_path in temp_files:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {file_path}: {e}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# =============================================================================
# AUTHENTICATION & USER MANAGEMENT ENDPOINTS
# =============================================================================

# Supabase configuration
SUPABASE_URL = os.environ.get("SUPABASE_URL")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY")
BREVO_API_KEY = os.environ.get("EMAIL_API")

# Initialize Supabase client
supabase: Client = None
try:
    if SUPABASE_URL and SUPABASE_KEY:
        supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
        logger.info("Supabase client initialized successfully")
    else:
        logger.warning("Supabase credentials not found in environment variables")
except Exception as e:
    logger.error(f"Failed to initialize Supabase: {str(e)}")

# OTP storage (in-memory, resets on restart)
otp_store = {}

# -------------------------------
# Request Models for Auth
# -------------------------------
class OTPRequest(BaseModel):
    email: str

class OTPVerifyRequest(BaseModel):
    email: str
    otp: str

class RegisterUserRequest(BaseModel):
    name: str
    email: str
    password: str

class SendEmotionRequest(BaseModel):
    email: str
    emotion: str

class UpdateProfilePicRequest(BaseModel):
    email: str
    profile_pic_url: str

class UpdateProfileRequest(BaseModel):
    email: str
    name: Optional[str] = None
    age: Optional[int] = None
    phone: Optional[str] = None

# -------------------------------
# Helper Function - Send Email via Brevo
# -------------------------------
def send_email_brevo(to_email: str, otp: str):
    url = "https://api.brevo.com/v3/smtp/email"
    headers = {
        "accept": "application/json",
        "api-key": BREVO_API_KEY,
        "content-type": "application/json",
    }
    data = {
        "sender": {"name": "मनन", "email": "noreplymanan@gmail.com"},
        "to": [{"email": to_email}],
        "subject": "Your Manan OTP Code",
        "htmlContent": f"""
        <html>
            <body>
                <h2>Your OTP Code</h2>
                <p>Your verification code is: <strong>{otp}</strong></p>
                <p>This code will expire in 5 minutes.</p>
            </body>
        </html>
        """,
    }
    response = requests.post(url, headers=headers, json=data)
    if response.status_code not in [200, 201]:
        raise HTTPException(status_code=500, detail=f"Email sending failed: {response.text}")

# -------------------------------
# Auth Endpoints
# -------------------------------
def get_status():
    try:
        if supabase is None:
            return {"status": "Supabase not configured", "supabase_data_retrieved": False}
        response = supabase.table('users').select("id").limit(1).execute()
        return {"status": "Connection successful", "supabase_data_retrieved": len(response.data) > 0}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Connection failed: {str(e)}")

def send_otp(req: OTPRequest):
    try:
        otp = str(random.randint(100000, 999999))
        otp_store[req.email] = {"otp": otp, "timestamp": datetime.utcnow()}
        logger.info(f"OTP generated for {req.email}")
        # Send email
        send_email_brevo(req.email, otp)
        return {"message": f"OTP sent successfully to {req.email}"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error sending OTP: {str(e)}")

def check_otp(req: OTPVerifyRequest):
    try:
        if req.email not in otp_store:
            raise HTTPException(status_code=404, detail="No OTP found for this email")
        stored_data = otp_store[req.email]
        stored_otp = stored_data["otp"]
        timestamp = stored_data["timestamp"]
        # Expiry check (5 minutes)
        if datetime.utcnow() - timestamp > timedelta(minutes=5):
            del otp_store[req.email]
            raise HTTPException(status_code=400, detail="OTP expired")
        if req.otp == stored_otp:
            # Mark as verified instead of deleting
            otp_store[req.email]["verified"] = True
            otp_store[req.email]["otp"] = None
            return {"verified": True}
        else:
            raise HTTPException(status_code=400, detail="Invalid OTP")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error verifying OTP: {str(e)}")

def register_user(req: RegisterUserRequest):
    try:
        if supabase is None:
            raise HTTPException(status_code=500, detail="Supabase not configured")
        # Check OTP verification
        if req.email not in otp_store or not otp_store[req.email].get("verified", False):
            raise HTTPException(status_code=403, detail="Email not verified via OTP")
        # Insert into Supabase
        response = supabase.table("users").insert({
            "name": req.name,
            "email": req.email,
            "password": req.password
        }).execute()
        # Cleanup
        del otp_store[req.email]
        return {"message": "User registered successfully", "data": response.data}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error registering user: {str(e)}")

def get_profile(email: str):
    try:
        if supabase is None:
            raise HTTPException(status_code=500, detail="Supabase not configured")
        response = supabase.table("users").select("*").eq("email", email).single().execute()
        if not response.data:
            raise HTTPException(status_code=404, detail="User not found")
        # Remove password from response for security
        profile_data = response.data.copy()
        profile_data.pop('password', None)
        return {"profile": profile_data}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error fetching profile: {str(e)}")

def update_profile(req: UpdateProfileRequest):
    try:
        if supabase is None:
            raise HTTPException(status_code=500, detail="Supabase not configured")
        update_data = {}
        if req.name is not None:
            update_data["name"] = req.name
        if req.age is not None:
            update_data["age"] = req.age
        if req.phone is not None:
            update_data["phone"] = req.phone
        if not update_data:
            raise HTTPException(status_code=400, detail="No data provided for update")
        response = supabase.table("users").update(update_data).eq("email", req.email).execute()
        if not response.data:
            raise HTTPException(status_code=404, detail="User not found")
        return {"message": "Profile updated successfully", "data": response.data}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error updating profile: {str(e)}")

def update_profile_pic(req: UpdateProfilePicRequest):
    try:
        if supabase is None:
            raise HTTPException(status_code=500, detail="Supabase not configured")
        response = supabase.table("users").update({
            "profilepic": req.profile_pic_url
        }).eq("email", req.email).execute()
        if not response.data:
            raise HTTPException(status_code=404, detail="User not found")
        return {"message": "Profile picture updated successfully", "data": response.data}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error updating profile picture: {str(e)}")

def get_score(email: str):
    try:
        if supabase is None:
            raise HTTPException(status_code=500, detail="Supabase not configured")
        user_response = supabase.table("users").select("id").eq("email", email).execute()
        if not user_response.data:
            raise HTTPException(status_code=404, detail="User not found")
        user_id = user_response.data[0]["id"]
        predict_response = supabase.table("predict").select("prediction, timestamp") \
            .eq("user_id", user_id).order("timestamp", desc=True).limit(1).execute()
        if not predict_response.data:
            raise HTTPException(status_code=404, detail="No predictions found for this user")
        return predict_response.data[0]
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error fetching score: {str(e)}")

def send_emotion(req: SendEmotionRequest):
    try:
        if supabase is None:
            raise HTTPException(status_code=500, detail="Supabase not configured")
        user_response = supabase.table("users").select("id").eq("email", req.email).execute()
        if not user_response.data:
            raise HTTPException(status_code=404, detail="User not found")
        user_id = user_response.data[0]["id"]
        data = {
            "user_id": user_id,
            "prediction": req.emotion,
            "timestamp": datetime.utcnow().isoformat()
        }
        response = supabase.table("predict").insert(data).execute()
        return {"message": "Emotion saved successfully", "data": response.data}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving emotion: {str(e)}")

def get_emotion(email: str):
    try:
        if supabase is None:
            raise HTTPException(status_code=500, detail="Supabase not configured")
        user_response = supabase.table("users").select("id").eq("email", email).execute()
        if not user_response.data:
            raise HTTPException(status_code=404, detail="User not found")
        user_id = user_response.data[0]["id"]
        predict_response = supabase.table("predict").select("prediction") \
            .eq("user_id", user_id).order("timestamp", desc=True).limit(1).execute()
        if not predict_response.data:
            return {"emotion": "Neutral"}
        return {"emotion": predict_response.data[0]["prediction"]}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error fetching emotion: {str(e)}")

def get_mental_health_details(email: str):
    try:
        if supabase is None:
            raise HTTPException(status_code=500, detail="Supabase not configured")
        user_response = supabase.table("users").select("id").eq("email", email).execute()
        if not user_response.data:
            raise HTTPException(status_code=404, detail="User not found")
        user_id = user_response.data[0]["id"]

        # Get most recent prediction
        recent_prediction = supabase.table("predict").select("prediction, timestamp") \
            .eq("user_id", user_id).order("timestamp", desc=True).limit(1).execute()
        current_prediction = recent_prediction.data[0]["prediction"] if recent_prediction.data else None

        # Calculate this week's active days
        week_ago = datetime.utcnow() - timedelta(days=7)
        weekly_predictions = supabase.table("predict").select("timestamp") \
            .eq("user_id", user_id).gte("timestamp", week_ago.isoformat()).execute()
        if weekly_predictions.data:
            dates = set()
            for p in weekly_predictions.data:
                if p.get("timestamp"):
                    try:
                        ts_str = p["timestamp"].replace('Z', '+00:00')
                        dates.add(datetime.fromisoformat(ts_str).date())
                    except (ValueError, AttributeError):
                        # Fall back to explicit formats that fromisoformat may reject
                        ts_str = str(p["timestamp"]).split('+')[0].split('Z')[0]
                        for fmt in ('%Y-%m-%dT%H:%M:%S.%f', '%Y-%m-%dT%H:%M:%S'):
                            try:
                                dates.add(datetime.strptime(ts_str, fmt).date())
                                break
                            except ValueError:
                                pass
            active_days = len(dates)
        else:
            active_days = 0

        # Total conversations
        total_conversations = supabase.table("predict").select("*", count="exact") \
            .eq("user_id", user_id).execute()
        total_count = total_conversations.count if total_conversations else 0

        return {
            "current_prediction": current_prediction,
            "active_days_this_week": active_days,
            "total_conversations": total_count
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error fetching mental health details: {str(e)}")

# =============================================================================
# DAILY EMOTION SCORE CALCULATION
# =============================================================================

# Valence Mapping for emotions
EMO_VALENCE = {
    "Angry": -0.80,
    "Disgust": -0.60,
    "Fear": -0.70,
    "Happy": 0.90,
    "Sad": -0.90,
    "Surprise": 0.20,
    "Neutral": 0.0,
    # text emotions
    "sadness": -0.90,
    "joy": 0.90,
    "love": 0.80,
    "anger": -0.80,
    "fear": -0.70,
    "surprise": 0.20
}

def valence_from_probabilities(probabilities: Dict[str, float]) -> float:
    if not probabilities:
        return 0.0
    v = 0.0
    total = sum(probabilities.values()) or 1.0
    for emo, p in probabilities.items():
        key = emo if emo in EMO_VALENCE else emo.capitalize()
        v += p * EMO_VALENCE.get(key, 0.0)
    return v / total

def valence_to_score(v: float) -> float:
    return (v + 1) / 2 * 100  # [-1..1] → [0..100]
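
# Worked example (numbers chosen for illustration): for {"Happy": 0.7, "Sad": 0.3}
# the probability-weighted valence is 0.7*0.9 + 0.3*(-0.9) = 0.36, and
# valence_to_score(0.36) = (0.36 + 1) / 2 * 100 = 68.0 on the 0-100 day-score scale.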

class EmotionItem(BaseModel):
    emotion: str
    probabilities: Optional[Dict[str, float]] = None
    confidence: Optional[float] = 1.0

class ScoreRequest(BaseModel):
    face_results: Optional[List[EmotionItem]] = []
    audio_results: Optional[List[EmotionItem]] = []
    text_results: Optional[List[EmotionItem]] = []

def calculate_day_score(payload: ScoreRequest):
    """Calculate weighted day score from face, audio, and text emotions."""
    source_weights = {
        "face": 0.4,
        "audio": 0.35,
        "text": 0.25
    }
    accum_num = 0.0
    accum_den = 0.0
    breakdown = {"face": [], "audio": [], "text": []}

    # FACE
    for item in payload.face_results:
        v = valence_from_probabilities(item.probabilities) \
            if item.probabilities else EMO_VALENCE.get(item.emotion, 0.0)
        score = valence_to_score(v)
        w = source_weights["face"] * (item.confidence or 1.0)
        accum_num += score * w
        accum_den += w
        breakdown["face"].append({
            "emotion": item.emotion,
            "valence": v,
            "score": score,
            "weight": w
        })

    # AUDIO
    for item in payload.audio_results:
        v = valence_from_probabilities(item.probabilities) \
            if item.probabilities else EMO_VALENCE.get(item.emotion, 0.0)
        score = valence_to_score(v)
        w = source_weights["audio"] * (item.confidence or 1.0)
        accum_num += score * w
        accum_den += w
        breakdown["audio"].append({
            "emotion": item.emotion,
            "confidence": item.confidence,
            "valence": v,
            "score": score,
            "weight": w
        })

    # TEXT
    for item in payload.text_results:
        v = valence_from_probabilities(item.probabilities) \
            if item.probabilities else EMO_VALENCE.get(item.emotion.lower(), 0.0)
        score = valence_to_score(v)
        w = source_weights["text"] * (item.confidence or 1.0)
        accum_num += score * w
        accum_den += w
        breakdown["text"].append({
            "emotion": item.emotion,
            "confidence": item.confidence,
            "valence": v,
            "score": score,
            "weight": w
        })

    final_score = accum_num / accum_den if accum_den > 0 else None
    return {
        "day_score": final_score,
        "breakdown": breakdown,
        "numerator": accum_num,
        "denominator": accum_den
    }
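
# Local usage sketch for the day-score calculation (kept as a comment so importing
# this module stays side-effect free; all values are illustrative):
#
#   payload = ScoreRequest(
#       face_results=[EmotionItem(emotion="Happy", probabilities={"Happy": 0.7, "Sad": 0.3})],
#       audio_results=[EmotionItem(emotion="Neutral", confidence=0.8)],
#       text_results=[EmotionItem(emotion="joy", confidence=0.9)],
#   )
#   result = calculate_day_score(payload)
#   # face: score 68.0, weight 0.4; audio: score 50.0, weight 0.28; text: score 95.0, weight 0.225
#   # day_score = (68.0*0.4 + 50.0*0.28 + 95.0*0.225) / (0.4 + 0.28 + 0.225) ≈ 69.1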