File size: 9,879 Bytes
ee9be8b
 
 
 
e7afee2
3290944
e7afee2
ee9be8b
 
 
 
1042819
3f6d97b
 
ee9be8b
2b9f49c
e7afee2
ee9be8b
 
 
 
 
 
 
 
 
 
2b9f49c
 
 
 
 
 
e7afee2
ee9be8b
 
e7afee2
ee9be8b
 
 
 
 
 
e7afee2
ee9be8b
 
 
 
 
 
 
e7afee2
 
 
ee9be8b
 
e7afee2
 
 
 
 
 
 
 
 
 
 
ee9be8b
 
e7afee2
103ddfa
2029c90
 
e7afee2
 
ee9be8b
 
e7afee2
ee9be8b
e7afee2
 
ee9be8b
3290944
ee9be8b
 
e7afee2
d3b998c
 
 
ee9be8b
 
 
 
 
 
3290944
52f0b05
 
 
 
 
 
 
 
3290944
 
 
52f0b05
3290944
 
 
 
52f0b05
 
 
3290944
 
52f0b05
 
3290944
 
 
 
 
 
 
52f0b05
 
3290944
 
52f0b05
3290944
52f0b05
 
 
 
 
34c9348
fe72903
 
 
 
 
 
 
52f0b05
 
 
 
 
 
 
 
 
 
 
 
 
 
3290944
 
 
 
52f0b05
3290944
52f0b05
 
 
 
 
 
 
 
 
 
 
 
3290944
 
 
52f0b05
 
 
 
 
 
 
 
 
 
3290944
 
 
 
 
 
 
 
 
 
 
 
 
 
ee9be8b
3290944
ee9be8b
 
3290944
 
 
 
 
 
ee9be8b
 
 
52f0b05
 
34c9348
 
 
3290944
 
 
 
34c9348
 
3290944
 
ee9be8b
e7afee2
ee9be8b
 
52f0b05
ee9be8b
3290944
e7afee2
 
 
 
52f0b05
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
from fastapi import FastAPI, HTTPException, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import snapshot_download
import os
import time
import logging
from typing import AsyncGenerator, Optional
import asyncio
import psutil
# 模型加载后,添加CPU指令集日志
import subprocess


# 日志与FastAPI初始化
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
logger = logging.getLogger(__name__)

app = FastAPI(title="CodeLlama-7B-Instruct (GGUF-4bit) CPU", version="1.0")
logger.info("FastAPI应用初始化完成")

# 执行命令查看CPU支持的指令集
try:
    avx2_support = subprocess.check_output("grep -c avx2 /proc/cpuinfo", shell=True).decode().strip()
    logger.info(f"CPU AVX2支持:{'是' if int(avx2_support) > 0 else '否'}")
except Exception as e:
    logger.warning(f"检测AVX2支持失败:{str(e)}")
# 跨域配置
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
logger.info("跨域中间件配置完成")

# API密钥验证
def verify_api_key(api_key: Optional[str] = Header(None, alias="api_key")):
    valid_key = os.getenv("NODE_API_KEY", "default-node-key-123")
    if not api_key or api_key != valid_key:
        logger.warning(f"无效API密钥:{api_key}")
        raise HTTPException(status_code=401, detail="Invalid or missing API Key")
    return api_key

# 自动下载GGUF模型
MODEL_REPO = "TheBloke/CodeLlama-7B-Instruct-GGUF"
MODEL_FILE = "codellama-7b-instruct.Q4_K_M.gguf"

try:
    logger.info(f"开始从Hugging Face下载模型:{MODEL_REPO}/{MODEL_FILE}")
    model_dir = snapshot_download(
        repo_id=MODEL_REPO,
        allow_patterns=[MODEL_FILE],
        local_dir="./models",
        local_dir_use_symlinks=False
    )
    model_path = os.path.join(model_dir, MODEL_FILE)
    logger.info(f"模型下载完成,保存到:{model_path}")

    # 加载GGUF模型
    model_load_start = time.time()
    llm = Llama(
        model_path=model_path,
        n_ctx=1024,
        n_threads=2,
        n_threads_batch=2,
        n_gpu_layers=0,
        verbose=False,
    )
    model_load_end = time.time()
    logger.info(f"模型加载完成!耗时 {model_load_end - model_load_start:.2f} 秒")
except Exception as e:
    logger.error(f"模型下载或加载失败:{str(e)}", exc_info=True)
    raise RuntimeError(f"Model setup failed: {str(e)}") from e

# 数据模型与流式推理
class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 150
    temperature: float = 0.4
    top_p: float = 0.6
    repetition_penalty: float = 1.1

@app.post("/generate/code/stream")
async def generate_code_stream(
    req: GenerationRequest,
    api_key: str = Depends(verify_api_key)
) -> StreamingResponse:
    request_id = f"req_{int(time.time() * 1000)}"
    # 补充请求参数日志,方便复现问题
    logger.info(
        f"收到推理请求 [{request_id}] | "
        f"prompt前30字符:'{req.prompt[:30]}...' | "
        f"max_new_tokens:{req.max_new_tokens} | "
        f"temperature:{req.temperature} | "
        f"top_p:{req.top_p}"
    )
    
    # 构建CodeLlama指令格式(必须严格遵循,否则生成效果差)
    formatted_prompt = f"<s>[INST] {req.prompt} [/INST]"
    logger.debug(f"请求 [{request_id}] 格式化后prompt:'{formatted_prompt[:50]}...'")  # debug级日志记录格式化结果
    
    # 流式生成器(适配Spaces前端SSE接收)
    async def stream_generator() -> AsyncGenerator[str, None]:
        start_time = time.time()
        generated_total = 0  # 累计生成词数
        token_count = 0       # 累计生成token数(细粒度进度)
        
        try:
            # 调用GGUF模型流式生成(CPU优化)
            logger.info(f"请求 [{request_id}] 开始推理,等待模型返回token流")
            for token_idx, token in enumerate(llm.create_completion(
                prompt=formatted_prompt,
                max_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
                stream=True,
                stop=["</s>"],
                echo=False
            )):
                token_count += 1  # 计数当前生成的token序号
                # 提取生成的文本片段
                text_chunk = token["choices"][0]["text"]
                
                if text_chunk:
                    # 计算累计生成词数(按空格分割,忽略空字符串)
                    current_chunk_words = len([w for w in text_chunk.split() if w.strip()])
                    generated_total += current_chunk_words
                    
                    # 记录每个有效文本片段的日志(每5个token打印一次,避免日志刷屏)
                    if token_idx % 10 == 0 or token_idx == 0:
                        cpu_percent = psutil.cpu_percent(interval=0.1)  # 整体CPU使用率
                        per_core_usage = psutil.cpu_percent(percpu=True, interval=0.1)  # 每个核心使用率
                        logger.info(
                            f"请求 [{request_id}] CPU监控 | "
                            f"整体使用率:{cpu_percent}% | "
                            f"各核心使用率:{per_core_usage[:8]}  # 显示前8核"
                        )    
                        logger.info(
                            f"请求 [{request_id}] 推理中 | "
                            f"当前token序号:{token_idx + 1} | "
                            f"累计token数:{token_count} | "
                            f"当前片段文本:'{text_chunk[:20]}...' | "
                            f"累计生成词数:{generated_total}"
                        )
                    else:
                        logger.debug(
                            f"请求 [{request_id}] 推理中 | "
                            f"当前token序号:{token_idx + 1} | "
                            f"当前片段文本:'{text_chunk[:20]}...'"
                        )
                    
                    # 按SSE格式返回(Spaces前端可直接解析)
                    yield f"data: {text_chunk}\n\n"
                    await asyncio.sleep(0.05)  # 微调延迟,避免前端接收过快
            
            # 生成完成,记录最终统计日志
            total_time = time.time() - start_time
            total_tokens = token_count  # 最终生成的总token数
            tokens_per_second = total_tokens / total_time if total_time > 0 else 0  # 计算token生成速率
            logger.info(
                f"请求 [{request_id}] 推理完成 | "
                f"总耗时:{total_time:.2f} 秒 | "
                f"生成总token数:{total_tokens} | "
                f"生成速率:{tokens_per_second:.2f} token/秒 | "
                f"累计生成词数:{generated_total} | "
                f"请求参数:max_new_tokens={req.max_new_tokens}, temperature={req.temperature}"
            )
            
            # 生成完成标记
            yield f"event: end\ndata: 生成完成!共{generated_total}个词,耗时{total_time:.2f}秒\n\n"
        
        except Exception as e:
            # 捕获推理过程中的异常,记录详细日志(含堆栈信息)
            error_time = time.time() - start_time
            logger.error(
                f"请求 [{request_id}] 推理失败 | "
                f"已耗时:{error_time:.2f} 秒 | "
                f"已生成token数:{token_count} | "
                f"已生成词数:{generated_total} | "
                f"错误原因:{str(e)}",
                exc_info=True  # 打印堆栈信息,方便定位问题
            )
            error_msg = f"生成失败:{str(e)}"
            yield f"event: error\ndata: {error_msg}\n\n"
            raise
    
    # 返回流式响应(适配Spaces HTTP服务)
    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )

# 根路径与健康检查
@app.get("/")
async def root():
    return {
        "status": "success",
        "service": "CodeLlama-7B-Instruct (GGUF-4bit) CPU",
        "message": "Spaces部署成功!调用 /generate/code/stream 接口生成代码",
        "model_info": f"模型:{MODEL_REPO},量化版本:{MODEL_FILE}"
    }

@app.get("/health")
async def health_check(api_key: str = Depends(verify_api_key)):
    # 补充健康检查时的模型状态日志
    logger.info(f"健康检查请求 | 模型状态:已加载 | 检查时间:{time.strftime('%Y-%m-%d %H:%M:%S')}")
    physical_cores = psutil.cpu_count(logical=False)  # 查看物理核心数
    logical_cores = psutil.cpu_count(logical=True)    # 查看逻辑核心数
    logger.info(f"物理核心数:{physical_cores},逻辑核心数:{logical_cores}")
    return {
        "status": "alive",
        "model_status": "loaded",
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "physical_cores": physical_cores,
        "logical_cores": logical_cores,
        "model_info": f"模型:{MODEL_REPO},量化版本:{MODEL_FILE}"
    }

# Spaces启动入口
if __name__ == "__main__":
    import uvicorn
    logger.info("启动Uvicorn服务 | 主机:0.0.0.0 | 端口:7860 | 工作进程数:1")
    uvicorn.run(
        app="model_space:app",
        host="0.0.0.0",
        port=7860,
        timeout_keep_alive=300,
        workers=1
    )