# app.py - Fashion Classification with PRETRAINED Color Detection
import gradio as gr
from transformers import CLIPProcessor, CLIPModel, pipeline
from PIL import Image
import torch
import numpy as np
# ======================
# Model Configuration
# ======================
print("[INFO] Loading models...")
# Fashion Classification Model
FASHION_MODEL = "patrickjohncyh/fashion-clip"
fashion_model = CLIPModel.from_pretrained(FASHION_MODEL)
fashion_processor = CLIPProcessor.from_pretrained(FASHION_MODEL)
print("[SUCCESS] ✅ Fashion-CLIP loaded!")
# Color Detection Model - general-purpose CLIP, used zero-shot for colors
color_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
color_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
print("[SUCCESS] ✅ Color Detection Model (CLIP) loaded!")
# ======================
# Fashion Categories
# ======================
FASHION_CATEGORIES = [
    # Indian Wear
    "saree", "kurta", "salwar kameez", "lehenga", "sherwani", "churidar",
    "anarkali", "palazzo", "kurti", "dhoti",
    # Western Wear
    "dress", "shirt", "t-shirt", "trousers", "jeans", "pants", "shorts",
    "skirt", "jacket", "coat", "sweater", "hoodie", "blazer", "cardigan",
    # Footwear
    "sneakers", "boots", "sandals", "heels", "flats", "slippers",
    # Accessories
    "handbag", "backpack", "hat", "scarf", "sunglasses", "watch", "belt"
]
# ======================
# Comprehensive Color List for CLIP
# ======================
COLOR_LABELS = [
    # Basic Colors
    "red", "blue", "green", "yellow", "orange", "purple", "pink",
    "brown", "black", "white", "gray", "grey",
    # Reds
    "dark red", "light red", "crimson", "maroon", "burgundy",
    "wine red", "cherry red", "scarlet",
    # Pinks
    "light pink", "hot pink", "coral", "salmon", "rose pink",
    "baby pink", "magenta", "fuchsia",
    # Oranges
    "dark orange", "light orange", "peach", "tangerine",
    "rust orange", "burnt orange",
    # Yellows
    "light yellow", "dark yellow", "golden yellow", "lemon yellow",
    "mustard yellow", "cream yellow", "amber",
    # Greens
    "dark green", "light green", "forest green", "olive green",
    "mint green", "lime green", "emerald green", "sage green",
    "teal", "sea green",
    # Blues
    "dark blue", "light blue", "navy blue", "royal blue",
    "sky blue", "baby blue", "turquoise", "cyan", "aqua",
    "indigo", "cobalt blue", "denim blue", "steel blue",
    # Purples
    "dark purple", "light purple", "violet", "lavender",
    "plum", "orchid", "mauve", "lilac",
    # Browns
    "dark brown", "light brown", "chocolate brown", "tan",
    "beige", "khaki", "caramel", "coffee brown",
    "taupe", "sand", "bronze",
    # Grays
    "light gray", "dark gray", "charcoal", "silver",
    "ash gray", "slate gray", "stone gray",
    # Special ("silver" and "burgundy" already appear above, so they are not repeated here)
    "gold", "copper", "cream", "ivory", "off-white",
    "wine", "rust", "denim", "multicolor"
]
# Prepare color prompts for better detection
COLOR_PROMPTS = [f"a {color} colored clothing item" for color in COLOR_LABELS]
# ======================
# Color Detection using CLIP (Pretrained)
# ======================
def detect_color_with_clip(image, top_k=3):
    """
    Detect color using the pretrained CLIP model.
    Returns: list of (color_name, confidence_score) tuples.
    """
    try:
        print("[INFO] 🎨 Detecting colors with CLIP model...")
        # Prepare paired text/image inputs
        inputs = color_processor(
            text=COLOR_PROMPTS,
            images=image,
            return_tensors="pt",
            padding=True
        )
        # Get predictions
        with torch.no_grad():
            outputs = color_model(**inputs)
        # Convert image-text similarity logits to probabilities
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)[0]
        # Get top K colors
        top_probs, top_indices = torch.topk(probs, k=top_k)
        detected_colors = []
        for prob, idx in zip(top_probs, top_indices):
            color_name = COLOR_LABELS[idx.item()]
            confidence = prob.item()
            # Only include if confidence > 5%
            if confidence > 0.05:
                detected_colors.append((color_name, confidence))
                print(f"[INFO] - {color_name}: {confidence:.1%}")
        if not detected_colors:
            return [("unknown", 0.0)]
        return detected_colors
    except Exception as e:
        print(f"[ERROR] Color detection failed: {e}")
        import traceback
        traceback.print_exc()
        return [("detection failed", 0.0)]
# ======================
# Alternative: Using Image Classification Pipeline
# ======================
def detect_color_with_pipeline(image):
    """
    Alternative: use a HuggingFace image-classification pipeline.
    Note: google/vit-base-patch16-224 predicts ImageNet classes, which
    rarely mention colors, so this fallback is coarse at best.
    """
    try:
        # Load a pretrained classifier
        # (replace with a dedicated color detection model if one is available)
        classifier = pipeline("image-classification", model="google/vit-base-patch16-224")
        results = classifier(image)
        # Keep only predictions whose labels mention a basic color
        color_keywords = ['red', 'blue', 'green', 'yellow', 'orange', 'purple',
                          'pink', 'brown', 'black', 'white', 'gray', 'grey']
        color_results = []
        for result in results:
            label_lower = result['label'].lower()
            for color in color_keywords:
                if color in label_lower:
                    color_results.append((color, result['score']))
                    break
        return color_results if color_results else [("multicolor", 0.5)]
    except Exception as e:
        print(f"[ERROR] Pipeline color detection failed: {e}")
        return [("unknown", 0.0)]
# ======================
# Clothing Type Detection
# ======================
def detect_clothing_type(category):
    """Detect whether the item is Indian or Western wear."""
    indian_wear = [
        'saree', 'kurta', 'salwar', 'lehenga', 'sherwani',
        'churidar', 'anarkali', 'kurti', 'dhoti', 'palazzo'
    ]
    category_lower = category.lower()
    for item in indian_wear:
        if item in category_lower:
            return "🇮🇳 Indian Wear"
    return "🌍 Western Wear"
# ======================
# Color Emoji Mapping
# ======================
def get_color_emoji(color_name):
    """Get an emoji for a color (first matching rule wins)."""
    color_lower = color_name.lower()
    if 'red' in color_lower or 'crimson' in color_lower or 'scarlet' in color_lower:
        return "🔴"
    elif 'pink' in color_lower or 'rose' in color_lower or 'coral' in color_lower:
        return "🩷"
    elif 'orange' in color_lower or 'peach' in color_lower or 'rust' in color_lower:
        return "🟠"
    elif 'yellow' in color_lower or 'gold' in color_lower or 'amber' in color_lower:
        return "🟡"
    elif 'green' in color_lower or 'olive' in color_lower or 'lime' in color_lower or 'emerald' in color_lower:
        return "🟢"
    elif 'blue' in color_lower or 'navy' in color_lower or 'cyan' in color_lower or 'aqua' in color_lower or 'denim' in color_lower:
        return "🔵"
    elif 'purple' in color_lower or 'violet' in color_lower or 'lavender' in color_lower or 'plum' in color_lower:
        return "🟣"
    elif 'brown' in color_lower or 'tan' in color_lower or 'beige' in color_lower or 'khaki' in color_lower:
        return "🟤"
    elif 'black' in color_lower or 'dark' in color_lower or 'charcoal' in color_lower:
        return "⚫"
    elif 'white' in color_lower or 'cream' in color_lower or 'ivory' in color_lower:
        return "⚪"
    elif 'gray' in color_lower or 'grey' in color_lower or 'silver' in color_lower:
        return "⚪"
    else:
        return "🎨"
# ======================
# Format Color Name
# ======================
def format_color_name(color):
    """Format a color name to title case."""
    return color.replace('_', ' ').title()
# ======================
# Main Prediction Function
# ======================
def predict_fashion(image, custom_categories=None):
    """
    Classify a fashion item and detect its color using pretrained models.
    """
    if image is None:
        return "⚠️ Please upload an image first!", {}
    try:
        # Step 1: Prepare categories (drop empty entries from stray commas)
        if custom_categories and custom_categories.strip():
            categories = [cat.strip() for cat in custom_categories.split(",") if cat.strip()]
        else:
            categories = FASHION_CATEGORIES

        # Step 2: Fashion item classification
        print("[INFO] 🔍 Classifying fashion item...")
        inputs = fashion_processor(
            text=categories,
            images=image,
            return_tensors="pt",
            padding=True
        )
        with torch.no_grad():
            outputs = fashion_model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)[0]
        top_prob_idx = probs.argmax().item()
        top_category = categories[top_prob_idx]
        top_confidence = probs[top_prob_idx].item()

        # Step 3: Color detection with pretrained CLIP
        detected_colors = detect_color_with_clip(image, top_k=3)

        # Step 4: Clothing type
        clothing_type = detect_clothing_type(top_category)

        # Step 5: Format results
        result = f"""
### 🎯 Fashion Item Detected
**Item:** {top_category.upper()} 👗
**Confidence:** {top_confidence:.1%}
**Type:** {clothing_type}
---
### 🎨 Color Detection (Pretrained CLIP Model)
"""
        # Display detected colors
        if detected_colors and detected_colors[0][0] not in ["unknown", "detection failed"]:
            # Primary color
            primary_color, primary_conf = detected_colors[0]
            primary_formatted = format_color_name(primary_color)
            primary_emoji = get_color_emoji(primary_color)
            result += f"**Primary Color:** {primary_emoji} **{primary_formatted}** ✨\n"
            result += f"*Confidence: {primary_conf:.1%}*\n\n"
            # Secondary colors
            if len(detected_colors) > 1:
                result += "**Secondary Colors:**\n"
                for i, (color, conf) in enumerate(detected_colors[1:], 1):
                    formatted = format_color_name(color)
                    emoji = get_color_emoji(color)
                    result += f" {i}. {emoji} **{formatted}** ({conf:.1%})\n"
                result += "\n"
            # Color summary
            color_names = [format_color_name(c[0]) for c in detected_colors]
            result += f"**Color Summary:** {', '.join(color_names)} 🌈\n"
        else:
            result += f"⚠️ Color detection: {detected_colors[0][0]}\n"

        result += f"""
---
### 📊 Detection Details
✅ **Fashion Model:** Fashion-CLIP (pretrained)
🎨 **Color Model:** CLIP Vision Transformer (pretrained)
🔍 **Color Database:** {len(COLOR_LABELS)} color categories
📊 **Classification Confidence:** {top_confidence:.1%}
🧠 **Method:** Zero-shot learning (no training needed)
---
### 💡 Styling Suggestions
"""
        # Clothing-type suggestions
        if "Indian" in clothing_type:
            result += """
- Perfect for traditional occasions 🪔
- Pair with ethnic jewelry
- Great for festivals and weddings
"""
        else:
            result += """
- Versatile for daily wear 👔
- Mix and match with other items
- Suitable for casual/formal settings
"""
        # Color-specific styling tips
        if detected_colors and detected_colors[0][0] not in ["unknown", "detection failed"]:
            primary_color = detected_colors[0][0].lower()
            result += f"\n**Styling Tips for {format_color_name(detected_colors[0][0])}:**\n"
            if 'black' in primary_color or 'dark' in primary_color:
                result += "- Timeless and elegant ⚫\n- Pairs with everything\n- Perfect for formal occasions\n"
            elif 'white' in primary_color or 'cream' in primary_color or 'ivory' in primary_color:
                result += "- Fresh and clean ⚪\n- Summer favorite\n- Easy to accessorize\n"
            elif 'gray' in primary_color or 'grey' in primary_color or 'silver' in primary_color:
                result += "- Sophisticated neutral ⚫\n- Professional choice\n- Modern aesthetic\n"
            elif 'red' in primary_color or 'maroon' in primary_color or 'crimson' in primary_color:
                result += "- Bold statement 🔴\n- Confidence booster\n- Pair with neutrals\n"
            elif 'blue' in primary_color or 'navy' in primary_color or 'denim' in primary_color:
                result += "- Classic choice 🔵\n- Versatile wear\n- Calming effect\n"
            elif 'green' in primary_color or 'olive' in primary_color:
                result += "- Natural vibe 🟢\n- Fresh look\n- Great for outdoors\n"
            elif 'yellow' in primary_color or 'gold' in primary_color:
                result += "- Cheerful color 🟡\n- Festive choice\n- Eye-catching\n"
            elif 'pink' in primary_color or 'coral' in primary_color:
                result += "- Soft and feminine 🩷\n- Romantic appeal\n- Party ready\n"
            elif 'purple' in primary_color or 'violet' in primary_color:
                result += "- Royal elegance 🟣\n- Unique choice\n- Sophisticated\n"
            elif 'brown' in primary_color or 'tan' in primary_color or 'beige' in primary_color:
                result += "- Earthy warmth 🟤\n- Natural look\n- Timeless style\n"
            elif 'orange' in primary_color or 'peach' in primary_color:
                result += "- Vibrant energy 🟠\n- Playful choice\n- Summer perfect\n"

        # Top fashion predictions for the gr.Label component
        top_probs, top_indices = torch.topk(probs, k=min(5, len(categories)))
        top_predictions = {}
        for prob, idx in zip(top_probs, top_indices):
            category = categories[idx.item()]
            top_predictions[category] = float(prob.item())
        return result, top_predictions
    except Exception as e:
        import traceback
        error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
        return error_msg, {}
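
# End-to-end usage sketch without the Gradio UI (hedged; "sample.jpg" is a
# hypothetical local path). Custom categories narrow the zero-shot label set:
#   from PIL import Image
#   markdown, top5 = predict_fashion(Image.open("sample.jpg"),
#                                    "red saree, blue jeans, white shirt")
#   print(markdown)   # formatted analysis text
#   print(top5)       # e.g. {"red saree": 0.91, ...} for the gr.Label widget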
# ======================
# Gradio Interface
# ======================
with gr.Blocks(title="Fashion AI with Pretrained Color Detection", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 👗 AI Fashion Classifier with Pretrained Color Detection
### Using CLIP Vision Transformer for Zero-Shot Color Recognition
""")
    gr.Markdown(f"""
**Fashion Model:** Fashion-CLIP (pretrained)
**Color Model:** OpenAI CLIP ViT-B/32 (pretrained)
**Color Categories:** {len(COLOR_LABELS)} colors
**Method:** Zero-shot learning (no dataset training needed)

### ✨ Why Pretrained Models?
1. 🎯 **Highly Accurate** - Trained on millions of images
2. ⚡ **Fast** - No manual preprocessing (e.g. background removal) needed
3. 🧠 **Smart** - Understands context and variations
4. 🔄 **Generalizable** - Works on any clothing type
5. 📊 **Reliable** - Consistent results
""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="📤 Upload Fashion Image")
            custom_categories = gr.Textbox(
                label="🏷️ Custom Categories (Optional)",
                placeholder="gray shorts, blue jeans, red kurta, white shirt",
                info="Comma-separated. Leave empty for the 30+ default categories."
            )
            predict_btn = gr.Button("🔍 Analyze with AI Models", variant="primary", size="lg")
            gr.Markdown(f"""
**💡 Tips:**
- Clear photos work best
- Good lighting recommended
- Single item preferred

**🎨 Supported Colors ({len(COLOR_LABELS)} types):**
- Basic: Red, Blue, Green, Yellow, Orange, Purple, Pink, Brown, Black, White, Gray
- Shades: Dark/Light variations
- Specific: Navy, Maroon, Teal, Lavender, Beige, etc.

**⚡ No Installation Needed:**
All models are pretrained and ready to use!
""")
        with gr.Column():
            output_text = gr.Markdown(label="📋 AI Analysis Results")
            output_label = gr.Label(label="📊 Top 5 Item Predictions", num_top_classes=5)

    # Event handler
    predict_btn.click(
        fn=predict_fashion,
        inputs=[input_image, custom_categories],
        outputs=[output_text, output_label]
    )
gr.Markdown(f"""
---
### 📝 Example Test Cases
| Item | Expected Colors |
|------|----------------|
| **Gray Shorts** | Gray, Light Gray, Dark Gray, Charcoal |
| **Denim Jeans** | Denim Blue, Navy Blue, Dark Blue |
| **Red Saree** | Red, Crimson, Dark Red |
| **White Shirt** | White, Off-White, Cream |
| **Black Kurta** | Black, Dark Gray, Charcoal |
| **Beige Dress** | Beige, Tan, Light Brown, Cream |
---
### 🎨 Color Detection Technology
**Model:** OpenAI CLIP (Contrastive Language-Image Pre-training)
**How it works:**
1. Image is processed through Vision Transformer
2. Compared with {len(COLOR_LABELS)} color text descriptions
3. Returns best matching colors with confidence scores
4. No background removal needed
5. Context-aware (understands "denim blue" vs "sky blue")
**Advantages over traditional methods:**
- ✅ Pretrained on 400M+ image-text pairs
- ✅ Understands color context (e.g., "denim blue", "burgundy red")
- ✅ No manual threshold tuning needed
- ✅ Works on complex patterns and textures
- ✅ Handles shadows and lighting variations
---
**🚀 Powered by:**
- Fashion-CLIP (patrickjohncyh/fashion-clip)
- OpenAI CLIP ViT-B/32
- HuggingFace Transformers
- Zero-shot learning (no training required)
""".format(len(COLOR_LABELS)))
# ======================
# Launch
# ======================
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("🚀 FASHION AI WITH PRETRAINED COLOR DETECTION")
    print("=" * 60)
    print("✅ Fashion Model: Fashion-CLIP (loaded)")
    print("✅ Color Model: CLIP ViT-B/32 (loaded)")
    print(f"✅ Fashion Categories: {len(FASHION_CATEGORIES)}")
    print(f"✅ Color Categories: {len(COLOR_LABELS)}")
    print("✅ Method: Zero-shot learning")
    print("✅ Background Removal: Not needed (AI handles it)")
    print("=" * 60 + "\n")
    demo.launch(server_name="0.0.0.0", server_port=7860)