"""
Advanced Mode Experiment API

API endpoints that let expert users run training stage by stage.

Endpoints:
- POST   /experiments                                        Create an experiment
- GET    /experiments                                        List experiments
- GET    /experiments/{exp_id}                               Get experiment details
- PATCH  /experiments/{exp_id}                               Update experiment configuration
- DELETE /experiments/{exp_id}                               Delete an experiment
- POST   /experiments/{exp_id}/stages/{stage_type}           Execute a stage
- GET    /experiments/{exp_id}/stages                        Get the status of all stages
- GET    /experiments/{exp_id}/stages/{stage_type}           Get stage details
- DELETE /experiments/{exp_id}/stages/{stage_type}           Cancel a stage
- GET    /experiments/{exp_id}/stages/{stage_type}/progress  Stage progress (SSE)
"""
|
|
| import json |
| from typing import Any, Dict, Optional |
|
|
| from fastapi import APIRouter, Body, Depends, HTTPException, Query |
| from fastapi.responses import StreamingResponse |
|
|
| from ....models.schemas.experiment import ( |
| ExperimentCreate, |
| ExperimentUpdate, |
| ExperimentResponse, |
| ExperimentListResponse, |
| StageStatus, |
| StageExecuteResponse, |
| StagesListResponse, |
| STAGE_DEPENDENCIES, |
| get_stage_params_class, |
| ) |
| from ....models.schemas.common import SuccessResponse, ErrorResponse |
| from ....services.experiment_service import ExperimentService |
| from ...deps import get_experiment_service |
|
|
# Router that the experiment endpoints in this module are registered on.
router = APIRouter()

# Stage names accepted in the `{stage_type}` path segment; these are the
# keys of STAGE_DEPENDENCIES.
VALID_STAGE_TYPES = list(STAGE_DEPENDENCIES)
|
|
|
|
@router.post(
    "",
    response_model=ExperimentResponse,
    summary="创建实验",
    description="""
创建实验(专家用户)。

创建实验但不立即执行,用户可以逐阶段控制训练流程。
实验创建后,所有阶段状态为 `pending`,需要手动触发执行。

**训练阶段**:
- `audio_slice`: 音频切片
- `asr`: 语音识别
- `text_feature`: 文本特征提取
- `hubert_feature`: HuBERT 特征提取
- `semantic_token`: 语义 Token 提取
- `sovits_train`: SoVITS 训练
- `gpt_train`: GPT 训练
""",
)
async def create_experiment(
    request: ExperimentCreate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Create an experiment from the request payload.

    The work is delegated to ``service.create_experiment`` and its result is
    returned unchanged.
    """
    created = await service.create_experiment(request)
    return created
|
|
|
|
@router.get(
    "",
    response_model=ExperimentListResponse,
    summary="获取实验列表",
    description="获取所有实验列表,支持按状态筛选和分页。",
)
async def list_experiments(
    status: Optional[str] = Query(None, description="按状态筛选"),
    limit: int = Query(50, ge=1, le=100, description="每页数量"),
    offset: int = Query(0, ge=0, description="偏移量"),
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentListResponse:
    """List experiments, forwarding the filter and paging query parameters.

    ``status``, ``limit`` and ``offset`` are passed straight through to
    ``service.list_experiments``; its result is returned as-is.
    """
    listing = await service.list_experiments(
        status=status,
        limit=limit,
        offset=offset,
    )
    return listing
|
|
|
|
@router.get(
    "/{exp_id}",
    response_model=ExperimentResponse,
    summary="获取实验详情",
    description="获取指定实验的详细信息,包括所有阶段状态。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Return the experiment identified by ``exp_id``.

    Raises:
        HTTPException: 404 when ``service.get_experiment`` returns a falsy value.
    """
    found = await service.get_experiment(exp_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="实验不存在")
|
|
|
|
@router.patch(
    "/{exp_id}",
    response_model=ExperimentResponse,
    summary="更新实验配置",
    description="更新实验的基础配置(非阶段参数)。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def update_experiment(
    exp_id: str,
    request: ExperimentUpdate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Apply ``request`` to the experiment identified by ``exp_id``.

    Raises:
        HTTPException: 404 when ``service.update_experiment`` returns a falsy
            value.
    """
    updated = await service.update_experiment(exp_id, request)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="实验不存在")
|
|
|
|
@router.delete(
    "/{exp_id}",
    response_model=SuccessResponse,
    summary="删除实验",
    description="删除实验及其所有阶段数据。如果有正在运行的阶段,会先取消执行。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def delete_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """Delete the experiment identified by ``exp_id``.

    Raises:
        HTTPException: 404 when ``service.delete_experiment`` reports failure
            (returns a falsy value).
    """
    deleted = await service.delete_experiment(exp_id)
    if deleted:
        return SuccessResponse(message="实验已删除")
    raise HTTPException(status_code=404, detail="实验不存在")
|
|
|
|
@router.post(
    "/{exp_id}/stages/{stage_type}",
    response_model=StageExecuteResponse,
    summary="执行阶段",
    description="""
执行指定阶段。

**阶段依赖关系**:
- `audio_slice`: 无依赖
- `asr`: 依赖 audio_slice
- `text_feature`: 依赖 asr
- `hubert_feature`: 依赖 audio_slice
- `semantic_token`: 依赖 hubert_feature
- `sovits_train`: 依赖 text_feature, semantic_token
- `gpt_train`: 依赖 text_feature, semantic_token

如果依赖阶段未完成,会返回 400 错误。
如果阶段已完成,会重新执行(返回 `rerun: true`)。
""",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效或依赖未满足"},
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def execute_stage(
    exp_id: str,
    stage_type: str,
    params: Dict[str, Any] = Body(default={}),
    service: ExperimentService = Depends(get_experiment_service),
) -> StageExecuteResponse:
    """Validate the request and start one stage of an experiment.

    Checks run in this order: stage name, experiment existence, stage
    dependencies, stage parameters. The first failing check decides the
    error response.

    Args:
        exp_id: Identifier of the experiment.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        params: Raw stage parameters from the request body, validated with
            the class returned by ``get_stage_params_class(stage_type)``.
        service: Experiment service injected by FastAPI.

    Returns:
        Whatever ``service.execute_stage`` returns for the validated request.

    Raises:
        HTTPException: 400 for an unknown stage type, unmet dependencies, or
            parameters rejected with ``ValueError``; 404 when the experiment
            lookup or ``service.execute_stage`` returns a falsy value.
    """
    # Reject unknown stage names before making any service call.
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}。有效类型: {', '.join(VALID_STAGE_TYPES)}",
        )

    experiment = await service.get_experiment(exp_id)
    if not experiment:
        raise HTTPException(status_code=404, detail="实验不存在")

    deps = await service.check_stage_dependencies(exp_id, stage_type)
    if not deps["satisfied"]:
        raise HTTPException(
            status_code=400,
            detail=f"依赖阶段未完成: {', '.join(deps['missing'])}",
        )

    try:
        params_class = get_stage_params_class(stage_type)
        validated_params = params_class(**params)
        # Only forward the fields the caller actually supplied.
        stage_params = validated_params.model_dump(exclude_unset=True)
    except ValueError as e:
        # Chain the original error so the validation failure stays attached
        # to the HTTP error as its explicit cause.
        raise HTTPException(status_code=400, detail=str(e)) from e

    result = await service.execute_stage(exp_id, stage_type, stage_params)
    if not result:
        raise HTTPException(status_code=404, detail="实验不存在")

    return result
|
|
|
|
@router.get(
    "/{exp_id}/stages",
    response_model=StagesListResponse,
    summary="获取所有阶段状态",
    description="获取实验的所有阶段状态列表。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_all_stages(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StagesListResponse:
    """Return the stage list of the experiment identified by ``exp_id``.

    Raises:
        HTTPException: 404 when ``service.get_all_stages`` returns a falsy
            value.
    """
    stages = await service.get_all_stages(exp_id)
    if stages:
        return stages
    raise HTTPException(status_code=404, detail="实验不存在")
|
|
|
|
@router.get(
    "/{exp_id}/stages/{stage_type}",
    response_model=StageStatus,
    summary="获取阶段详情",
    description="获取指定阶段的详细状态和结果。",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def get_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StageStatus:
    """Return the status of one stage of an experiment.

    Raises:
        HTTPException: 400 when ``stage_type`` is not in
            ``VALID_STAGE_TYPES``; 404 when ``service.get_stage`` returns a
            falsy value.
    """
    known_stage = stage_type in VALID_STAGE_TYPES
    if not known_stage:
        raise HTTPException(status_code=400, detail=f"无效的阶段类型: {stage_type}")

    stage_info = await service.get_stage(exp_id, stage_type)
    if stage_info:
        return stage_info
    raise HTTPException(status_code=404, detail="实验或阶段不存在")
|
|
|
|
@router.delete(
    "/{exp_id}/stages/{stage_type}",
    response_model=SuccessResponse,
    summary="取消阶段",
    description="取消正在执行的阶段。只有运行中的阶段可以取消。",
    responses={
        400: {"model": ErrorResponse, "description": "阶段未运行或无法取消"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def cancel_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """Cancel one stage of an experiment.

    Args:
        exp_id: Identifier of the experiment.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        service: Experiment service injected by FastAPI.

    Returns:
        A ``SuccessResponse`` whose message names the cancelled stage.

    Raises:
        HTTPException: 400 when ``stage_type`` is unknown or
            ``service.cancel_stage`` returns a falsy value; 404 when the
            experiment lookup returns a falsy value.
    """
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}",
        )

    # The route declares a 404 response, so look the experiment up first;
    # otherwise a missing experiment would be reported as a 400
    # "not running" error.
    experiment = await service.get_experiment(exp_id)
    if not experiment:
        raise HTTPException(status_code=404, detail="实验不存在")

    success = await service.cancel_stage(exp_id, stage_type)
    if not success:
        raise HTTPException(
            status_code=400,
            detail="阶段未运行或无法取消",
        )

    return SuccessResponse(message=f"阶段 {stage_type} 已取消")
|
|
|
|
@router.get(
    "/{exp_id}/stages/{stage_type}/progress",
    summary="SSE 阶段进度订阅",
    description="""
订阅阶段进度更新(Server-Sent Events)。

返回的事件流格式:
```
event: progress
data: {"epoch": 8, "total_epochs": 16, "progress": 0.50, "loss": 0.034}

event: checkpoint
data: {"epoch": 8, "model_path": "logs/my_voice/sovits_e8.pth"}

event: completed
data: {"status": "completed", "final_loss": 0.023}
```
""",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def subscribe_stage_progress(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StreamingResponse:
    """Stream progress updates for one stage as Server-Sent Events.

    Each update yielded by ``service.subscribe_stage_progress`` becomes one
    ``event:`` / ``data:`` frame, with the update serialized as JSON.

    Raises:
        HTTPException: 400 when ``stage_type`` is not in
            ``VALID_STAGE_TYPES``; 404 when the experiment lookup returns a
            falsy value.
    """
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(status_code=400, detail=f"无效的阶段类型: {stage_type}")

    if not await service.get_experiment(exp_id):
        raise HTTPException(status_code=404, detail="实验不存在")

    # Statuses that end the stream once their frame has been sent.
    terminal_names = ("completed", "failed", "cancelled")

    async def event_generator():
        """Yield one SSE frame per progress update."""
        async for update in service.subscribe_stage_progress(exp_id, stage_type):
            state = update.get("status")
            terminal_name = next(
                (name for name in terminal_names if state == name),
                None,
            )

            # Event name priority: terminal status, then checkpoint (update
            # carries a model path), then the update's own "type".
            if terminal_name is not None:
                event_name = terminal_name
            elif update.get("model_path"):
                event_name = "checkpoint"
            else:
                event_name = update.get("type", "progress")

            payload = json.dumps(update, ensure_ascii=False)
            yield f"event: {event_name}\ndata: {payload}\n\n"

            if terminal_name is not None:
                break

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|