# Scamzero-ml-api / api.py
# Author: YOGEXH — commit 486f81d ("Upload 5 files", verified)
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AlbertForSequenceClassification, AlbertTokenizer
import torch
# Directory containing the saved ALBERT classifier and its tokenizer files.
# NOTE(review): this is a relative path, so it is resolved against the
# process's current working directory, not against this file's location.
MODEL_PATH = "./saved_model"

app = FastAPI()

# Both objects are created once, at import time, and shared by every request.
tokenizer = AlbertTokenizer.from_pretrained(MODEL_PATH)
# nn.Module.eval() switches the model to inference mode (e.g. dropout off)
# and returns the same module, so it can be chained onto the load call.
model = AlbertForSequenceClassification.from_pretrained(MODEL_PATH).eval()
class JobInput(BaseModel):
    """Request body accepted by the ``/predict`` endpoint."""

    # Job posting text to classify; handed unchanged to the tokenizer.
    description: str
@app.get("/")
def read_root():
    """Return a greeting plus a map of the routes this service exposes."""
    available_endpoints = {
        "predict": "/predict (POST)",
        "health": "/health (GET)",
    }
    return {
        "message": "ScamZero Prediction API is running!",
        "endpoints": available_endpoints,
    }
@app.get("/health")
def health_check():
    """Return a static liveness payload.

    NOTE(review): the ``model`` value is a fixed string literal; it is not
    read from the checkpoint loaded from MODEL_PATH. Confirm it still
    matches the deployed model.
    """
    status_payload = {"status": "healthy", "model": "albert-base-v2"}
    return status_payload
@app.post("/predict")
def predict(data: JobInput):
    """Classify a single job description.

    The text is tokenized into PyTorch tensors, truncated to at most 256
    tokens, and scored by the ALBERT classifier with gradient tracking off.
    A softmax over the logits gives class probabilities, and the most
    probable class is reported.

    Returns:
        dict with ``prediction`` (int class index), ``label`` ("Scam" when
        the index is 1, otherwise "Safe") and ``confidence`` (float softmax
        probability of the predicted class).
    """
    encoded = tokenizer(
        data.description,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
    )
    with torch.no_grad():
        logits = model(**encoded).logits

    # One input string means one batch row; take it once and work in 1-D.
    class_probs = torch.softmax(logits, dim=-1)[0]
    class_index = torch.argmax(class_probs, dim=-1).item()

    if class_index == 1:
        label = "Scam"
    else:
        label = "Safe"

    return {
        "prediction": class_index,
        "label": label,
        "confidence": class_probs[class_index].item(),
    }