fiewolf1000 committed
Commit ee9be8b · verified · 1 Parent(s): 06dbf35

Create model_space.py

Files changed (1)
  1. model_space.py +157 -0
model_space.py ADDED
@@ -0,0 +1,157 @@
+ from fastapi import FastAPI, HTTPException, Depends, Header
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel
+ from llama_cpp import Llama  # core library for loading GGUF models
+ from huggingface_hub import hf_hub_download  # downloads the GGUF weights from the Hugging Face Hub
+ import os  # used by verify_api_key to read NODE_API_KEY
+ import time
+ import logging
+ from typing import AsyncGenerator, Optional
+ import asyncio
+
+ # ---------------------- 1. Logging and FastAPI initialization (for Spaces) ----------------------
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(levelname)s - %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S"
+ )
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI(title="CodeLlama-7B-Instruct (GGUF-4bit) CPU", version="1.0")
+ logger.info("FastAPI app initialized")
+
+ # CORS configuration (required for calls from the Spaces frontend)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # covers the Spaces frontend domain automatically; no change needed
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+ logger.info("CORS middleware configured")
+
+ # ---------------------- 2. API key verification (original logic kept, adapted for Spaces security) ----------------------
+ def verify_api_key(api_key: Optional[str] = Header(None, alias="api_key")):
+     # Spaces can inject the key via an environment variable instead of hardcoding it
+     # (set NODE_API_KEY in the Space's settings)
+     valid_key = os.getenv("NODE_API_KEY", "default-node-key-123")
+     if not api_key or api_key != valid_key:
+         logger.warning(f"Invalid API key: {api_key}")
+         raise HTTPException(status_code=401, detail="Invalid or missing API Key")
+     return api_key
+
+ # ---------------------- 3. Load the GGUF model (key step: 4-bit quantization, CPU-optimized) ----------------------
+ # Model source: a Hugging Face repo ID; the GGUF file is fetched at startup via
+ # hf_hub_download (and reused from cache on restart), so no model file has to be
+ # uploaded manually. Note that plain Llama(model_path=...) does not download by itself.
+ MODEL_ID = "TheBloke/CodeLlama-7B-Instruct-GGUF"
+ MODEL_FILE = "codellama-7b-instruct.Q4_K_M.gguf"  # 4-bit quantized file (small memory footprint, good quality)
+
+ try:
+     model_load_start = time.time()
+     logger.info(f"Loading model: {MODEL_ID}/{MODEL_FILE} (CPU environment)")
+
+     # Download the GGUF file from the Hub (or reuse the local cached copy)
+     model_path = hf_hub_download(repo_id=MODEL_ID, filename=MODEL_FILE)
+
+     # Core GGUF loading configuration (CPU-only, capped memory usage)
+     llm = Llama(
+         model_path=model_path,
+         n_ctx=2048,           # context length (2048 is enough for code generation; larger costs memory)
+         n_threads=4,          # Spaces CPUs are usually 4 cores, so 4 threads is optimal
+         n_threads_batch=4,    # batch threads, kept equal to n_threads
+         n_gpu_layers=0,       # 0 = pure CPU (the free Spaces tier has no GPU)
+         verbose=False,        # silence verbose logs to keep Spaces output small
+     )
+
+     model_load_end = time.time()
+     logger.info(f"Model loaded in {model_load_end - model_load_start:.2f}s, ~3.5GB memory footprint")
+ except Exception as e:
+     logger.error(f"Model load failed: {str(e)}", exc_info=True)
+     raise RuntimeError(f"Model load failed: {str(e)}") from e
+
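+ # Note: recent versions of llama-cpp-python also ship a convenience constructor that
+ # downloads from the Hub and loads the model in one call. A minimal sketch, assuming
+ # an installed version that provides Llama.from_pretrained:
+ #
+ #     llm = Llama.from_pretrained(
+ #         repo_id=MODEL_ID,
+ #         filename=MODEL_FILE,
+ #         n_ctx=2048,
+ #         n_threads=4,
+ #         n_gpu_layers=0,
+ #     )
+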
+ # ---------------------- 4. Request schema and streaming inference (for the Spaces frontend) ----------------------
+ class GenerationRequest(BaseModel):
+     prompt: str
+     max_new_tokens: int = 150  # cap output length (150 is plenty on CPU; longer gets slow)
+     temperature: float = 0.6   # 0.5-0.7 recommended for code: balances diversity and accuracy
+     top_p: float = 0.9
+
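+ # For reference, the JSON body the endpoint below expects (illustrative values; only
+ # "prompt" is required, the rest fall back to the defaults above):
+ #
+ #     {"prompt": "Write a Python quicksort", "max_new_tokens": 150,
+ #      "temperature": 0.6, "top_p": 0.9}
+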
+ @app.post("/generate/code/stream")
+ async def generate_code_stream(
+     req: GenerationRequest,
+     api_key: str = Depends(verify_api_key)
+ ) -> StreamingResponse:
+     request_id = f"req_{int(time.time() * 1000)}"
+     logger.info(f"Received request {request_id}: prompt='{req.prompt[:30]}...'")
+
+     # Build the CodeLlama instruction format (must be followed exactly, or output quality suffers)
+     formatted_prompt = f"<s>[INST] {req.prompt} [/INST]"
+
+     # Streaming generator (SSE, consumable by the Spaces frontend)
+     async def stream_generator() -> AsyncGenerator[str, None]:
+         start_time = time.time()
+         generated_total = 0
+         try:
+             # Stream tokens from the GGUF model (CPU-optimized); note this iterator is
+             # synchronous, so the event loop blocks while each token is computed
+             for token in llm.create_completion(
+                 prompt=formatted_prompt,
+                 max_tokens=req.max_new_tokens,
+                 temperature=req.temperature,
+                 top_p=req.top_p,
+                 stream=True,     # enable streaming
+                 stop=["</s>"],   # stop sequence (avoids trailing output)
+                 echo=False       # do not echo the input prompt
+             ):
+                 # Extract the generated text fragment
+                 text_chunk = token["choices"][0]["text"]
+                 if text_chunk:
+                     generated_total += len(text_chunk.split())  # rough word count
+                     # Emit in SSE format (directly parseable by the Spaces frontend)
+                     yield f"data: {text_chunk}\n\n"
+                     await asyncio.sleep(0.05)  # small delay so the frontend is not flooded
+
+             # Completion marker
+             total_time = time.time() - start_time
+             yield f"event: end\ndata: Generation complete: {generated_total} words in {total_time:.2f}s\n\n"
+             logger.info(f"Request {request_id} finished in {total_time:.2f}s")
+
+         except Exception as e:
+             error_msg = f"Generation failed: {str(e)}"
+             logger.error(f"Request {request_id} error: {error_msg}")
+             # Report the error as an SSE event and end the stream; re-raising here
+             # would tear down a response that has already started
+             yield f"event: error\ndata: {error_msg}\n\n"
+
+     # Return the streaming response (served over Spaces HTTP)
+     return StreamingResponse(
+         stream_generator(),
+         media_type="text/event-stream",
+         headers={
+             "Cache-Control": "no-cache",   # disable caching so chunks arrive in real time
+             "Connection": "keep-alive",
+             "X-Accel-Buffering": "no"      # disable proxy buffering, which breaks streaming
+         }
+     )
+
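+ # For reference, an illustrative sketch of what the SSE stream looks like on the wire
+ # (not captured output):
+ #
+ #     data: def add(a, b):
+ #
+ #     data:     return a + b
+ #
+ #     event: end
+ #     data: Generation complete: 12 words in 8.31s
+ #
+ # Caveat: a chunk containing its own newline would need per-line "data:" framing
+ # under a strict SSE encoder; the simple f-string above does not handle that case.
+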
+ # ---------------------- 5. Root path and health check (to verify the Spaces deployment) ----------------------
+ @app.get("/")
+ async def root():
+     return {
+         "status": "success",
+         "service": "CodeLlama-7B-Instruct (GGUF-4bit) CPU",
+         "message": "Spaces deployment OK! Call /generate/code/stream to generate code",
+         "model_info": f"model: {MODEL_ID}, quantization: 4-bit, memory: ~3.5GB"
+     }
+
+ @app.get("/health")
+ async def health_check(api_key: str = Depends(verify_api_key)):
+     return {
+         "status": "alive",
+         "model_status": "loaded",
+         "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
+     }
+
+ # ---------------------- 6. Spaces entry point (must use uvicorn to match Spaces process management) ----------------------
+ if __name__ == "__main__":
+     import uvicorn
+     logger.info("Starting Uvicorn server (Spaces CPU)")
+     uvicorn.run(
+         app="model_space:app",      # must match this file's name (model_space.py)
+         host="0.0.0.0",             # Spaces requires binding to 0.0.0.0
+         port=7860,                  # default Spaces port
+         timeout_keep_alive=300,     # long keep-alive timeout, needed for streaming responses
+         workers=1                   # one worker is enough on CPU; more just costs memory
+     )
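+
+ # A minimal client sketch for consuming the stream, assuming the server runs on
+ # localhost:7860 and NODE_API_KEY is left at its default (uses the requests library):
+ #
+ #     import requests
+ #
+ #     resp = requests.post(
+ #         "http://localhost:7860/generate/code/stream",
+ #         json={"prompt": "Write a Python quicksort"},
+ #         headers={"api_key": "default-node-key-123"},
+ #         stream=True,
+ #     )
+ #     for line in resp.iter_lines(decode_unicode=True):
+ #         if line.startswith("data: "):
+ #             print(line[len("data: "):], end="", flush=True)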