import os

import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import FileResponse
from PIL import Image, ImageFilter
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation

app = FastAPI()

# Load the model and processor once at startup so they are not reloaded on every request
preprocessor = AutoImageProcessor.from_pretrained("google/deeplabv3_mobilenet_v2_1.0_513")
model = AutoModelForSemanticSegmentation.from_pretrained("google/deeplabv3_mobilenet_v2_1.0_513")
model.eval()  # inference only: keep batch-norm statistics fixed and disable dropout

# Create the directory for saving output if it doesn’t exist
os.makedirs("output", exist_ok=True)

def get_segmentation_mask(image: Image.Image) -> Image.Image:
    """Generate a binary segmentation mask with feathered edges from an input image."""
    # Step 1: Preprocess and run model inference
    inputs = preprocessor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Step 2: Post-process the output into a class-label map, then binarise it so that
    # background (class 0) becomes 0 and every other class becomes 255
    predicted_mask = preprocessor.post_process_semantic_segmentation(outputs)[0]
    mask_np = (predicted_mask.cpu().numpy() != 0).astype("uint8") * 255
    binary_mask = Image.fromarray(mask_np)

    # Step 3: Apply a slight Gaussian blur to soften the edges
    feathered_mask = binary_mask.filter(ImageFilter.GaussianBlur(1))  # Adjust blur radius as needed
    feathered_mask = feathered_mask.resize(image.size, Image.BICUBIC)

    return feathered_mask
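
# Illustrative variant (not used by the endpoint below): keep only one class, e.g.
# people, instead of everything that is not background. Which labels exist depends
# on what the checkpoint was trained on, so look the id up in model.config.id2label
# rather than hardcoding it; this is a sketch under that assumption.
def get_class_mask(image: Image.Image, class_id: int) -> Image.Image:
    """Return a feathered mask keeping only pixels predicted as `class_id`."""
    inputs = preprocessor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # target_sizes expects (height, width); PIL's image.size is (width, height)
    labels = preprocessor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    mask_np = (labels.cpu().numpy() == class_id).astype("uint8") * 255
    return Image.fromarray(mask_np).filter(ImageFilter.GaussianBlur(1))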

def apply_mask_to_image(image: Image.Image, mask: Image.Image) -> Image.Image:
    """Apply the segmentation mask to the input image to create a transparent sticker."""
    image = image.convert("RGBA")  # Ensure image is in RGBA mode
    sticker = Image.new("RGBA", image.size)
    sticker.paste(image, (0, 0), mask)  # Use mask as the alpha channel
    return sticker

@app.post("/create_sticker/")
async def create_sticker(file: UploadFile = File(...)):
    """Endpoint to convert an uploaded image to a sticker."""
    if file.content_type not in ["image/png", "image/jpeg"]:
        raise HTTPException(status_code=400, detail="Invalid image format. Only PNG and JPEG are supported.")

    # Load the image
    input_image = Image.open(file.file).convert("RGB")

    # Generate segmentation mask and apply it to create a sticker
    mask = get_segmentation_mask(input_image)
    sticker = apply_mask_to_image(input_image, mask)

    # Save the output sticker, always as PNG so the alpha channel is preserved;
    # strip any client-supplied path components from the filename
    safe_name = os.path.splitext(os.path.basename(file.filename or "upload"))[0]
    output_path = f"output/sticker_{safe_name}.png"
    sticker.save(output_path, "PNG")

    return FileResponse(output_path, media_type="image/png", filename=f"sticker_{safe_name}.png")
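
# Example request, assuming the server is running locally on port 8000 and a
# photo.jpg exists in the current directory (adjust names as needed):
#
#   curl -F "file=@photo.jpg;type=image/jpeg" \
#        http://localhost:8000/create_sticker/ -o sticker.png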


# Run the app
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)