# canary_aed_streaming/app/session_utils.py
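"""Session bookkeeping for the canary_aed_streaming Gradio app.

Maintains a JSON registry of active Gradio session hashes under TMP_DIR and
provides helpers to create and clean up the per-session flag files
(stream / task) and audio-chunk folders used while streaming and transcribing.
"""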
import os
import json
import shutil
from datetime import datetime
from app.logger_config import logger as logging
import gradio as gr
# TMP_DIR = "/tmp/canary_aed_streaming"
TMP_DIR = os.getenv("TMP_DIR", "/tmp/canary_aed_streaming")
ACTIVE_SESSIONS_HASH_FILE = os.path.join(TMP_DIR, "active_session_hash_code.json")
ACTIVE_STREAM_FLAG = "stream_active_"
ACTIVE_TASK_FLAG = "task_active_"
NAME_FOLDER_CHUNKS = "chunks_"
# ---------------------------
# Helper to manage the JSON
# ---------------------------
def _read_session_hash_code():
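    """Returns the registry of active session hashes, or {} if the file is missing or unreadable."""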
if not os.path.exists(ACTIVE_SESSIONS_HASH_FILE):
return {}
try:
with open(ACTIVE_SESSIONS_HASH_FILE, "r") as f:
return json.load(f)
except Exception:
return {}
def _write_session_hash_code(data):
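    """Writes the registry of active session hashes to disk, creating its directory if needed."""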
os.makedirs(os.path.dirname(ACTIVE_SESSIONS_HASH_FILE), exist_ok=True)
with open(ACTIVE_SESSIONS_HASH_FILE, "w") as f:
json.dump(data, f, indent=2)
# ---------------------------
# LOAD
# ---------------------------
def on_load(request: gr.Request):
"""Called when a new visitor opens the app."""
session_hash_code = request.session_hash
sessions = _read_session_hash_code()
sessions[session_hash_code] = {
"session_hash_code": session_hash_code,
"file": "",
"start_time": datetime.utcnow().strftime("%H:%M:%S"),
"status": "active",
}
_write_session_hash_code(sessions)
logging.info(f"[{session_hash_code}] session_hash_code registered (on_load).")
return session_hash_code, session_hash_code # can be used as gr.State + display
# ---------------------------
# UNLOAD
# ---------------------------
def on_unload(request: gr.Request):
"""Called when the visitor closes or refreshes the app."""
session_hash_code = request.session_hash
sessions = _read_session_hash_code()
if session_hash_code in sessions:
sessions.pop(session_hash_code)
_write_session_hash_code(sessions)
remove_session_hash_code_data(session_hash_code)
unregister_session_hash_code_hash(session_hash_code)
logging.info(f"[{session_hash_code}] session_hash_code removed (on_unload).")
else:
logging.info(f"[{session_hash_code}] No active session_hash_code found to remove.")
def ensure_tmp_dir():
"""Ensures the base temporary directory exists."""
try:
os.makedirs(TMP_DIR, exist_ok=True)
except Exception as e:
logging.error(f"Failed to create tmp directory {TMP_DIR}: {e}")
def reset_all_active_sessions():
"""Removes all temporary session_hash_code files and folders at startup."""
ensure_tmp_dir()
try:
# --- Remove active session_hash_codes file ---
if os.path.exists(ACTIVE_SESSIONS_HASH_FILE):
os.remove(ACTIVE_SESSIONS_HASH_FILE)
logging.info("Active session_hash_codes file reset at startup.")
else:
logging.debug("No active session_hash_codes file found to reset.")
# --- Clean all flag files (stream + transcribe) ---
for f in os.listdir(TMP_DIR):
            if f.startswith((ACTIVE_TASK_FLAG, ACTIVE_STREAM_FLAG)) and f.endswith(".txt"):
path = os.path.join(TMP_DIR, f)
try:
os.remove(path)
logging.debug(f"Removed leftover flag file: {f}")
except Exception as e:
logging.warning(f"Failed to remove flag file {f}: {e}")
# --- Clean chunk directories ---
for name in os.listdir(TMP_DIR):
path = os.path.join(TMP_DIR, name)
            if os.path.isdir(path) and name.startswith(NAME_FOLDER_CHUNKS):
try:
shutil.rmtree(path)
logging.debug(f"Removed leftover chunk folder: {name}")
except Exception as e:
logging.warning(f"Failed to remove chunk folder {name}: {e}")
logging.info("Temporary session cleanup completed successfully.")
except Exception as e:
logging.error(f"Error resetting active session_hash_codes: {e}")
def remove_session_hash_code_data(session_hash_code: str):
"""Removes all temporary files and data related to a specific session_hash_code."""
if not session_hash_code:
logging.warning("reset_session() called without a valid session_hash_code.")
return
try:
        # --- Remove session_hash_code from the active-session registry file ---
if os.path.exists(ACTIVE_SESSIONS_HASH_FILE):
try:
with open(ACTIVE_SESSIONS_HASH_FILE, "r") as f:
data = json.load(f)
if session_hash_code in data:
data.pop(session_hash_code)
with open(ACTIVE_SESSIONS_HASH_FILE, "w") as f:
json.dump(data, f, indent=2)
logging.debug(f"[{session_hash_code}] Removed from {ACTIVE_SESSIONS_HASH_FILE}.")
except Exception as e:
logging.warning(f"[{session_hash_code}] Failed to update {ACTIVE_SESSIONS_HASH_FILE}: {e}")
        # --- Remove all per-session temporary files and folders ---
remove_active_task_flag_file(session_hash_code)
remove_active_stream_flag_file(session_hash_code)
remove_chunk_folder(session_hash_code)
logging.info(f"[{session_hash_code}] session_hash_code fully reset.")
except Exception as e:
logging.error(f"[{session_hash_code}] Error during reset_session: {e}")
def register_session_hash_code(session_hash_code: str, filepath: str):
    """Registers a new session_hash_code and its associated file path."""
    ensure_tmp_dir()
    data = _read_session_hash_code()
    data[session_hash_code] = {
        "session_hash_code": session_hash_code,
        "file": filepath,
        "start_time": datetime.utcnow().strftime("%H:%M:%S"),
        "status": "active",
    }
    _write_session_hash_code(data)
    logging.debug(f"[{session_hash_code}] session_hash_code registered in {ACTIVE_SESSIONS_HASH_FILE}.")
def unregister_session_hash_code_hash(session_hash_code: str):
"""Removes a session_hash_code from the registry."""
if not os.path.exists(ACTIVE_SESSIONS_HASH_FILE):
return
try:
with open(ACTIVE_SESSIONS_HASH_FILE, "r") as f:
data = json.load(f)
if session_hash_code in data:
data.pop(session_hash_code)
with open(ACTIVE_SESSIONS_HASH_FILE, "w") as f:
json.dump(data, f)
logging.debug(f"[{session_hash_code}] session_hash_code unregistered.")
except Exception as e:
logging.error(f"[{session_hash_code}] Error unregistering session_hash_code: {e}")
def get_active_session_hashes():
"""Returns active session_hash_codes as a list of rows for the DataFrame."""
if not os.path.exists(ACTIVE_SESSIONS_HASH_FILE):
return []
try:
with open(ACTIVE_SESSIONS_HASH_FILE, "r") as f:
data = json.load(f)
rows = [
[
s.get("session_hash_code", ""),
s.get("file", ""),
s.get("start_time", ""),
s.get("status", ""),
]
for s in data.values()
]
return rows
except Exception as e:
logging.error(f"Error reading active session_hash_codes: {e}")
return []
def get_active_task_flag_file(session_hash_code: str):
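    """Returns the path of the flag file marking an active transcription task for this session."""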
return os.path.join(TMP_DIR, f"{ACTIVE_TASK_FLAG}{session_hash_code}.txt")
def get_active_stream_flag_file(session_hash_code: str):
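    """Returns the path of the flag file marking an active stream for this session."""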
return os.path.join(TMP_DIR, f"{ACTIVE_STREAM_FLAG}{session_hash_code}.txt")
def remove_active_stream_flag_file(session_hash_code: str):
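    """Removes the per-session stream flag file if it exists."""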
fname = os.path.join(TMP_DIR, f"{ACTIVE_STREAM_FLAG}{session_hash_code}.txt")
if os.path.exists(fname):
try:
os.remove(fname)
logging.debug(f"[{session_hash_code}] Removed file: {fname}")
except Exception as e:
logging.warning(f"[{session_hash_code}] Failed to remove file {fname}: {e}")
def remove_active_task_flag_file(session_hash_code: str):
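    """Removes the per-session task flag file if it exists."""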
fname = os.path.join(TMP_DIR, f"{ACTIVE_TASK_FLAG}{session_hash_code}.txt")
if os.path.exists(fname):
try:
os.remove(fname)
logging.debug(f"[{session_hash_code}] Removed file: {fname}")
except Exception as e:
logging.warning(f"[{session_hash_code}] Failed to remove file {fname}: {e}")
def remove_chunk_folder(session_hash_code: str):
    """Removes the per-session chunk folder if it exists."""
    chunk_dir = os.path.join(TMP_DIR, f"{NAME_FOLDER_CHUNKS}{session_hash_code}")
if os.path.isdir(chunk_dir):
try:
shutil.rmtree(chunk_dir)
logging.debug(f"[{session_hash_code}] Removed chunk folder: chunks_{session_hash_code}")
except Exception as e:
logging.warning(f"[{session_hash_code}] Failed to remove chunk folder: {e}")
def get_session_hashe_chunks_dir(session_hash_code: str):
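    """Returns the per-session directory where audio chunks are stored."""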
return os.path.join(TMP_DIR, f"{NAME_FOLDER_CHUNKS}{session_hash_code}")