Spaces:
Sleeping
Sleeping
Alex
commited on
Commit
·
f6210c2
1
Parent(s):
7364060
switched to segformer fashion
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from fastapi import FastAPI, File, UploadFile, HTTPException
|
| 2 |
-
from transformers import
|
| 3 |
import torch
|
| 4 |
from PIL import Image
|
| 5 |
import numpy as np
|
|
@@ -14,11 +14,11 @@ app = FastAPI()
|
|
| 14 |
logging.basicConfig(level=logging.INFO)
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
-
# Carica il modello e il processore
|
| 18 |
try:
|
| 19 |
-
logger.info("Caricamento del modello
|
| 20 |
-
model =
|
| 21 |
-
processor =
|
| 22 |
model.to("cpu") # Usa CPU per il free tier
|
| 23 |
logger.info("Modello caricato con successo.")
|
| 24 |
except Exception as e:
|
|
@@ -27,23 +27,23 @@ except Exception as e:
|
|
| 27 |
|
| 28 |
# Funzione per segmentare l'immagine
|
| 29 |
def segment_image(image: Image.Image):
|
| 30 |
-
# Prepara l'input per
|
| 31 |
logger.info("Preparazione dell'immagine per l'inferenza...")
|
| 32 |
-
inputs = processor(image, return_tensors="pt").to("cpu")
|
| 33 |
|
| 34 |
# Inferenza
|
| 35 |
logger.info("Esecuzione dell'inferenza...")
|
| 36 |
with torch.no_grad():
|
| 37 |
-
outputs = model(**inputs
|
| 38 |
-
|
|
|
|
| 39 |
# Post-processa la maschera
|
| 40 |
logger.info("Post-processing della maschera...")
|
| 41 |
-
mask =
|
| 42 |
-
|
| 43 |
-
)[0][0].cpu().numpy()
|
| 44 |
|
| 45 |
# Converti la maschera in immagine
|
| 46 |
-
mask_img = Image.fromarray((mask * 255).astype(np.uint8))
|
| 47 |
|
| 48 |
# Converti la maschera in base64 per la risposta
|
| 49 |
buffered = io.BytesIO()
|
|
@@ -51,7 +51,7 @@ def segment_image(image: Image.Image):
|
|
| 51 |
mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 52 |
|
| 53 |
# Annotazioni
|
| 54 |
-
annotations = {"mask": mask.tolist(), "label": "
|
| 55 |
|
| 56 |
return mask_base64, annotations
|
| 57 |
|
|
@@ -74,7 +74,7 @@ async def segment_endpoint(file: UploadFile = File(...)):
|
|
| 74 |
logger.error(f"Errore nell'endpoint: {str(e)}")
|
| 75 |
raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}")
|
| 76 |
|
| 77 |
-
# Per compatibilità con Hugging Face Spaces
|
| 78 |
if __name__ == "__main__":
|
| 79 |
import uvicorn
|
| 80 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 1 |
from fastapi import FastAPI, File, UploadFile, HTTPException
|
| 2 |
+
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
|
| 3 |
import torch
|
| 4 |
from PIL import Image
|
| 5 |
import numpy as np
|
|
|
|
| 14 |
logging.basicConfig(level=logging.INFO)
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
+
# Carica il modello e il processore SegFormer
|
| 18 |
try:
|
| 19 |
+
logger.info("Caricamento del modello SegFormer...")
|
| 20 |
+
model = SegformerForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
|
| 21 |
+
processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion")
|
| 22 |
model.to("cpu") # Usa CPU per il free tier
|
| 23 |
logger.info("Modello caricato con successo.")
|
| 24 |
except Exception as e:
|
|
|
|
| 27 |
|
| 28 |
# Funzione per segmentare l'immagine
|
| 29 |
def segment_image(image: Image.Image):
|
| 30 |
+
# Prepara l'input per SegFormer
|
| 31 |
logger.info("Preparazione dell'immagine per l'inferenza...")
|
| 32 |
+
inputs = processor(images=image, return_tensors="pt").to("cpu")
|
| 33 |
|
| 34 |
# Inferenza
|
| 35 |
logger.info("Esecuzione dell'inferenza...")
|
| 36 |
with torch.no_grad():
|
| 37 |
+
outputs = model(**inputs)
|
| 38 |
+
logits = outputs.logits
|
| 39 |
+
|
| 40 |
# Post-processa la maschera
|
| 41 |
logger.info("Post-processing della maschera...")
|
| 42 |
+
mask = torch.argmax(logits, dim=1)[0]
|
| 43 |
+
mask = mask.cpu().numpy()
|
|
|
|
| 44 |
|
| 45 |
# Converti la maschera in immagine
|
| 46 |
+
mask_img = Image.fromarray((mask * 255 / mask.max()).astype(np.uint8))
|
| 47 |
|
| 48 |
# Converti la maschera in base64 per la risposta
|
| 49 |
buffered = io.BytesIO()
|
|
|
|
| 51 |
mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 52 |
|
| 53 |
# Annotazioni
|
| 54 |
+
annotations = {"mask": mask.tolist(), "label": "fashion"}
|
| 55 |
|
| 56 |
return mask_base64, annotations
|
| 57 |
|
|
|
|
| 74 |
logger.error(f"Errore nell'endpoint: {str(e)}")
|
| 75 |
raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}")
|
| 76 |
|
| 77 |
+
# Per compatibilità con Hugging Face Spaces
|
| 78 |
if __name__ == "__main__":
|
| 79 |
import uvicorn
|
| 80 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|