Spaces:
Paused
Paused
| import glob | |
| import os | |
| import pathlib | |
| import shutil | |
| from typing import Union | |
| from fastapi import BackgroundTasks, Depends, Path, Request, UploadFile | |
| from fastapi.params import File | |
| from fastapi.responses import FileResponse, StreamingResponse | |
| from loguru import logger | |
| from app.config import config | |
| from app.controllers import base | |
| from app.controllers.manager.memory_manager import InMemoryTaskManager | |
| from app.controllers.manager.redis_manager import RedisTaskManager | |
| from app.controllers.v1.base import new_router | |
| from app.models.exception import HttpException | |
| from app.models.schema import ( | |
| AudioRequest, | |
| BgmRetrieveResponse, | |
| BgmUploadResponse, | |
| SubtitleRequest, | |
| TaskDeletionResponse, | |
| TaskQueryRequest, | |
| TaskQueryResponse, | |
| TaskResponse, | |
| TaskVideoRequest, | |
| ) | |
| from app.services import state as sm | |
| from app.services import task as tm | |
| from app.utils import utils | |
# Authentication dependency (base.verify_token) attached to this router.
router = new_router(dependencies=[Depends(base.verify_token)])
# router = new_router()
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)
_max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5)
# Only include the ":password@" userinfo when a password is configured;
# otherwise the URL would carry the literal string "None" as the password.
_redis_auth = f":{_redis_password}@" if _redis_password else ""
redis_url = f"redis://{_redis_auth}{_redis_host}:{_redis_port}/{_redis_db}"
# Choose the task manager implementation according to configuration.
if _enable_redis:
    task_manager = RedisTaskManager(
        max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url
    )
else:
    task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks)
def create_video(
    background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest
):
    """Create a task via ``create_task`` with ``stop_at="video"``.

    ``background_tasks`` is accepted but not used by this handler.
    """
    return create_task(request=request, body=body, stop_at="video")
def create_subtitle(
    background_tasks: BackgroundTasks, request: Request, body: SubtitleRequest
):
    """Create a task via ``create_task`` with ``stop_at="subtitle"``.

    ``background_tasks`` is accepted but not used by this handler.
    """
    return create_task(request=request, body=body, stop_at="subtitle")
def create_audio(
    background_tasks: BackgroundTasks, request: Request, body: AudioRequest
):
    """Create a task via ``create_task`` with ``stop_at="audio"``.

    ``background_tasks`` is accepted but not used by this handler.
    """
    return create_task(request=request, body=body, stop_at="audio")
def create_task(
    request: Request,
    body: Union[TaskVideoRequest, SubtitleRequest, AudioRequest],
    stop_at: str,
):
    """Register a new task in state and hand it to the task manager.

    :param request: incoming request, used to derive the request id
    :param body: task parameters passed to ``tm.start`` as ``params``
    :param stop_at: pipeline stage forwarded to ``tm.start``
    :return: 200 response containing ``task_id`` and ``request_id``
    :raises HttpException: 400 when a ``ValueError`` is raised while registering
        or queueing the task
    """
    task_id = utils.get_uuid()
    request_id = base.get_task_id(request)
    try:
        task = {
            "task_id": task_id,
            "request_id": request_id,
            "params": "[REDACTED]",  # do not expose the full parameters in logs
        }
        sm.state.update_task(task_id)
        task_manager.add_task(tm.start, task_id=task_id, params=body, stop_at=stop_at)
        logger.success(f"Task created: {utils.to_json(task)}")
        return utils.get_response(200, {"task_id": task_id, "request_id": request_id})
    except ValueError as e:
        # Keep the client message generic and the log free of parameter values,
        # but leave a server-side trace and chain the original cause.
        logger.warning(
            f"Task {task_id} (request {request_id}) rejected: {type(e).__name__}"
        )
        raise HttpException(
            task_id=task_id, status_code=400, message="Invalid request parameters"
        ) from e
| from fastapi import Query | |
def get_all_tasks(
    request: Request,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1),
):
    """Return one page of tasks plus the total count and the paging values used."""
    # The request id is resolved but currently not included in the response.
    request_id = base.get_task_id(request)
    page_items, total_count = sm.state.get_all_tasks(page, page_size)
    return utils.get_response(
        200,
        dict(tasks=page_items, total=total_count, page=page, page_size=page_size),
    )
def get_task(
    request: Request,
    task_id: str = Path(..., description="Task ID"),
    query: TaskQueryRequest = Depends(),
):
    """Return a task's state, with its video paths rewritten as public URLs.

    Paths under the tasks directory listed in ``videos`` / ``combined_videos``
    are converted to ``{endpoint}/tasks/...`` URLs. The endpoint comes from
    ``config.app["endpoint"]``, falling back to the request's base URL.

    :raises HttpException: 404 when the task does not exist
    """
    endpoint = config.app.get("endpoint", "")
    if not endpoint:
        endpoint = str(request.base_url)
    endpoint = endpoint.rstrip("/")

    request_id = base.get_task_id(request)
    task = sm.state.get_task(task_id)
    if not task:
        raise HttpException(
            task_id=task_id, status_code=404, message=f"{request_id}: task not found"
        )

    task_dir = utils.task_dir()

    def file_to_uri(file):
        # Values that already start with the endpoint are returned unchanged,
        # so a task whose paths were converted earlier is not rewritten twice.
        if file.startswith(endpoint):
            return file
        uri_path = file.replace(task_dir, "tasks").replace("\\", "/")
        return f"{endpoint}/{uri_path}"

    for key in ("videos", "combined_videos"):
        if key in task:
            task[key] = [file_to_uri(path) for path in task[key]]
    return utils.get_response(200, task)
def delete_video(request: Request, task_id: str = Path(..., description="Task ID")):
    """Remove a task's directory (if present) and its state entry.

    :raises HttpException: 404 when the task is unknown
    """
    # The request id is resolved but currently unused by this handler.
    request_id = base.get_task_id(request)
    if not sm.state.get_task(task_id):
        raise HttpException(
            task_id=task_id, status_code=404, message="Task not found"
        )
    task_folder = os.path.join(utils.task_dir(), task_id)
    if os.path.exists(task_folder):
        shutil.rmtree(task_folder)
    sm.state.delete_task(task_id)
    # Log only the task id, never the full task record.
    logger.success(f"Task deleted: {task_id}")
    return utils.get_response(200)
def get_bgm_list(request: Request):
    """List the ``*.mp3`` files in the song directory with name, size and path."""
    pattern = os.path.join(utils.song_dir(), "*.mp3")
    entries = [
        {
            "name": os.path.basename(path),
            "size": os.path.getsize(path),
            "file": path,
        }
        for path in glob.glob(pattern)
    ]
    return utils.get_response(200, {"files": entries})
def upload_bgm_file(request: Request, file: UploadFile = File(...)):
    """Save an uploaded ``.mp3`` file into the song directory.

    An existing file with the same name is overwritten.

    :return: 200 response with the saved path under ``file``
    :raises HttpException: 400 when the upload's name does not end in ``.mp3``
    """
    request_id = base.get_task_id(request)
    # The filename is client-controlled: keep only its final component so a
    # name such as "../../x.mp3" cannot write outside the song directory.
    # UploadFile.filename may also be None.
    filename = os.path.basename(file.filename or "")
    # Require the dot so names like "notmp3" are rejected; case-sensitive to
    # match the "*.mp3" pattern used when listing songs.
    if filename.endswith(".mp3"):
        save_path = os.path.join(utils.song_dir(), filename)
        with open(save_path, "wb") as buffer:
            file.file.seek(0)
            # Copy in chunks instead of loading the whole upload into memory.
            shutil.copyfileobj(file.file, buffer)
        return utils.get_response(200, {"file": save_path})
    raise HttpException(
        "", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded"
    )
# Maximum number of bytes read from disk per chunk yielded by stream_video.
_STREAM_CHUNK_SIZE = 64 * 1024


def _parse_byte_range(range_header: str, size: int):
    """Parse a single-range ``Range`` header into an inclusive ``(start, end)``.

    Accepts ``bytes=START-END``, ``bytes=START-`` and the suffix form
    ``bytes=-N``; ``end`` is clamped to the last byte of the file. Returns
    ``None`` when the header is malformed, uses a unit other than ``bytes``,
    lists several ranges, or cannot be satisfied. The caller then serves the
    whole file, which is allowed because a server MAY ignore ``Range``.
    """
    unit, sep, spec = range_header.partition("=")
    if not sep or unit.strip().lower() != "bytes" or "," in spec:
        return None
    first, dash, last = spec.strip().partition("-")
    if not dash:
        return None
    first, last = first.strip(), last.strip()
    try:
        if not first:
            # Suffix range: the final N bytes of the file.
            suffix_length = int(last)
            if suffix_length <= 0:
                return None
            start, end = max(size - suffix_length, 0), size - 1
        else:
            start = int(first)
            end = int(last) if last else size - 1
    except ValueError:
        return None
    end = min(end, size - 1)
    if start < 0 or start > end:
        return None
    return start, end


async def stream_video(request: Request, file_path: str):
    """Stream an MP4 from the tasks directory, honouring a single byte range.

    :param request: incoming request; only its ``Range`` header is read
    :param file_path: path of the video relative to the tasks directory
    :return: 206 Partial Content for a satisfiable single range, otherwise
        200 with the whole file
    :raises HttpException: 404 when the path leaves the tasks directory or
        does not name an existing file
    """
    tasks_dir = os.path.realpath(utils.task_dir())
    # file_path comes from the URL. Strip leading separators so os.path.join
    # cannot discard tasks_dir, then resolve it and confirm it stays inside.
    video_path = os.path.realpath(os.path.join(tasks_dir, file_path.lstrip("/\\")))
    if not video_path.startswith(tasks_dir + os.sep) or not os.path.isfile(video_path):
        raise HttpException(task_id="", status_code=404, message="Video not found")

    video_size = os.path.getsize(video_path)
    range_header = request.headers.get("Range")
    byte_range = _parse_byte_range(range_header, video_size) if range_header else None
    if byte_range is None:
        # No usable Range header: a full 200 response, not a partial one.
        start, end, status_code = 0, video_size - 1, 200
    else:
        (start, end), status_code = byte_range, 206
    length = end - start + 1

    def file_iterator(path, offset, count):
        with open(path, "rb") as f:
            f.seek(offset, os.SEEK_SET)
            remaining = count
            while remaining > 0:
                data = f.read(min(_STREAM_CHUNK_SIZE, remaining))
                if not data:
                    break
                remaining -= len(data)
                yield data

    headers = {"Accept-Ranges": "bytes", "Content-Length": str(length)}
    if status_code == 206:
        headers["Content-Range"] = f"bytes {start}-{end}/{video_size}"
    return StreamingResponse(
        file_iterator(video_path, start, length),
        status_code=status_code,
        headers=headers,
        media_type="video/mp4",
    )
async def download_video(_: Request, file_path: str):
    """
    download video
    :param _: Request request
    :param file_path: video file path relative to the tasks directory,
        eg: /cd1727ed-3473-42a2-a7da-4faafafec72b/final-1.mp4
    :return: video file served as an attachment
    :raises HttpException: 404 when the path leaves the tasks directory or
        does not name an existing file
    """
    tasks_dir = os.path.realpath(utils.task_dir())
    # file_path comes from the URL and may start with "/" (see the example).
    # Strip leading separators so os.path.join keeps tasks_dir, then resolve
    # and confirm the target is still inside it.
    video_path = os.path.realpath(os.path.join(tasks_dir, file_path.lstrip("/\\")))
    if not video_path.startswith(tasks_dir + os.sep) or not os.path.isfile(video_path):
        raise HttpException(task_id="", status_code=404, message="Video not found")

    video_file = pathlib.Path(video_path)
    extension = video_file.suffix
    media_type = f"video/{extension[1:]}" if extension else "application/octet-stream"
    # FileResponse builds a correctly quoted/RFC 5987-encoded attachment
    # Content-Disposition from `filename`; a hand-built header would break on
    # spaces, semicolons or non-Latin-1 characters.
    return FileResponse(
        path=video_path,
        filename=video_file.name,
        media_type=media_type,
    )