aryanprooo committed on
Commit b663a21 · verified · 1 Parent(s): cb98570

Update app.py

Files changed (1)
  1. app.py +292 -372
app.py CHANGED
@@ -1,34 +1,29 @@
- # app.py - Fashion Classification with Fashion-CLIP + Smart Color Detection
 
  import gradio as gr
- from transformers import CLIPProcessor, CLIPModel
  from PIL import Image
  import torch
  import numpy as np
- import cv2
- from sklearn.cluster import KMeans
 
  # ======================
- # Fashion-CLIP Model Configuration
  # ======================
- MODEL_NAME = "patrickjohncyh/fashion-clip"
 
- print("[INFO] Loading Fashion-CLIP model...")
- model = CLIPModel.from_pretrained(MODEL_NAME)
- processor = CLIPProcessor.from_pretrained(MODEL_NAME)
- print("[SUCCESS] Fashion-CLIP model loaded!")
 
- # Try to import background removal library
- try:
- from rembg import remove
- REMBG_AVAILABLE = True
- print("[SUCCESS] Background removal (rembg) available!")
- except ImportError:
- REMBG_AVAILABLE = False
- print("[WARNING] rembg not available. Install with: pip install rembg")
 
  # ======================
- # Define Fashion Categories
  # ======================
  FASHION_CATEGORIES = [
  # Indian Wear
@@ -47,243 +42,152 @@ FASHION_CATEGORIES = [
  ]
 
  # ======================
- # Background Removal
  # ======================
- def remove_background(image):
- """
- Remove background from image - clothing only
- """
- if REMBG_AVAILABLE:
- try:
- # Use rembg for high-quality background removal
- output = remove(image)
- print("[INFO] ✅ Background removed using rembg")
- return output
- except Exception as e:
- print(f"[WARNING] rembg failed: {e}")
-
- # Fallback: Use GrabCut
- try:
- img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
- mask = np.zeros(img_cv.shape[:2], np.uint8)
-
- bgd_model = np.zeros((1, 65), np.float64)
- fgd_model = np.zeros((1, 65), np.float64)
-
- height, width = img_cv.shape[:2]
- margin_h = int(height * 0.05)
- margin_w = int(width * 0.05)
- rect = (margin_w, margin_h, width - 2*margin_w, height - 2*margin_h)
-
- cv2.grabCut(img_cv, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
-
- mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
- img_cv = img_cv * mask2[:, :, np.newaxis]
-
- img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
- alpha = (mask2 * 255).astype(np.uint8)
- img_rgba = np.dstack((img_rgb, alpha))
-
- print("[INFO] ✅ Background removed using GrabCut")
- return Image.fromarray(img_rgba, 'RGBA')
-
- except Exception as e:
- print(f"[ERROR] Background removal failed: {e}")
- return image
 
 
  # ======================
- # Dominant Color Detection using K-Means
  # ======================
- def get_dominant_color(image, n_colors=3):
  """
- Get THE most dominant colors from clothing using K-means clustering
 
- Returns: List of color names and their RGB values
  """
  try:
- # Step 1: Remove background
- print("[INFO] 🔄 Removing background...")
- img_no_bg = remove_background(image)
-
- # Step 2: Convert to numpy array
- img_array = np.array(img_no_bg)
-
- # Step 3: Extract only non-transparent pixels (clothing only)
- if img_array.shape[-1] == 4: # RGBA
- alpha = img_array[:, :, 3]
- mask = alpha > 100
- pixels = img_array[:, :, :3][mask]
- else: # RGB
- pixels = img_array.reshape(-1, 3)
-
- # Step 4: Filter out pure white and pure black
- pixels = pixels[~((pixels[:, 0] > 240) & (pixels[:, 1] > 240) & (pixels[:, 2] > 240))]
- pixels = pixels[~((pixels[:, 0] < 15) & (pixels[:, 1] < 15) & (pixels[:, 2] < 15))]
-
- if len(pixels) < 10:
- return ["Unable to detect"], None
 
- print(f"[INFO] 📊 Analyzing {len(pixels)} clothing pixels...")
-
- # Step 5: Sample pixels if too many
- if len(pixels) > 5000:
- indices = np.random.choice(len(pixels), 5000, replace=False)
- pixels = pixels[indices]
-
- # Step 6: K-Means clustering
- kmeans = KMeans(n_clusters=min(n_colors, len(pixels)), random_state=42, n_init=10)
- kmeans.fit(pixels)
 
- # Step 7: Get cluster centers
- dominant_colors_rgb = kmeans.cluster_centers_
 
- # Step 8: Sort by frequency
- labels = kmeans.labels_
- label_counts = np.bincount(labels)
- sorted_indices = np.argsort(label_counts)[::-1]
 
- # Step 9: Convert to color names
- color_names = []
- color_rgb_values = []
 
- for idx in sorted_indices[:n_colors]:
- rgb = dominant_colors_rgb[idx].astype(int)
- color_name = rgb_to_color_name(tuple(rgb))
 
- if color_name not in color_names:
- color_names.append(color_name)
- color_rgb_values.append(rgb)
 
- print(f"[INFO] 🎨 Dominant colors detected: {', '.join(color_names)}")
 
- return color_names, color_rgb_values
 
  except Exception as e:
  print(f"[ERROR] Color detection failed: {e}")
  import traceback
  traceback.print_exc()
- return ["Detection Failed"], None
 
 
- def rgb_to_color_name(rgb):
  """
- Convert RGB to human-readable color name
  """
- r, g, b = rgb
-
- # White
- if r > 220 and g > 220 and b > 220:
- return "White"
-
- # Black
- if r < 40 and g < 40 and b < 40:
- return "Black"
-
- # Gray
- if abs(r - g) < 30 and abs(g - b) < 30 and 40 <= r <= 220:
- if r > 160:
- return "Light Gray"
- elif r > 100:
- return "Gray"
- else:
- return "Dark Gray"
-
- # Red family
- if r > max(g, b) + 30:
- if r > 200 and g < 100:
- return "Red"
- elif r > 150 and g > 50 and g < 150:
- if b < 80:
- return "Orange"
- else:
- return "Coral"
- elif r > 100 and g < 80 and b < 80:
- return "Maroon"
- elif r > 180 and b > 100:
- return "Pink"
-
- # Orange
- if r > 200 and 100 < g < 180 and b < 100:
- return "Orange"
-
- # Yellow
- if r > 200 and g > 200 and b < 150:
- if b < 80:
- return "Yellow"
- else:
- return "Light Yellow"
-
- # Green family
- if g > max(r, b) + 30:
- if g > 200 and r > 150:
- return "Light Green"
- elif g > 150 and r < 100 and b < 100:
- return "Green"
- elif g > 100 and r > 80 and b < 80:
- return "Olive"
- elif r < 100 and b < 100:
- return "Dark Green"
-
- # Blue family
- if b > max(r, g) + 30:
- if b > 200 and r < 100 and g < 100:
- return "Blue"
- elif b > 150 and r < 80 and g < 80:
- return "Navy"
- elif b > 150 and g > 100:
- return "Sky Blue"
- elif b > 100 and r > 80 and g < 100:
- return "Purple"
-
- # Cyan/Turquoise
- if g > 150 and b > 150 and r < 100:
- return "Cyan"
-
- # Purple/Violet
- if r > 100 and b > 100 and g < 100:
- if r > 150 and b > 150:
- return "Purple"
- elif r > b:
- return "Magenta"
- else:
- return "Violet"
-
- # Pink
- if r > 180 and b > 120 and g < 150:
- return "Pink"
-
- # Brown
- if 80 < r < 180 and 40 < g < 140 and b < 100:
- if r > 140:
- return "Brown"
- else:
- return "Dark Brown"
-
- # Beige/Tan
- if r > 180 and g > 150 and b > 100 and r > b:
- return "Beige"
-
- # Gold
- if r > 200 and 140 < g < 200 and b < 100:
- return "Gold"
-
- # Cream
- if r > 220 and g > 200 and 150 < b < 200:
- return "Cream"
-
- # Lavender
- if r > 180 and g > 160 and b > 200:
- return "Lavender"
-
- # Default
- return "Multicolor"
 
 
  def detect_clothing_type(category):
- """
- Detect if Indian or Western wear
- """
  indian_wear = [
  'saree', 'kurta', 'salwar', 'lehenga', 'sherwani',
  'churidar', 'anarkali', 'kurti', 'dhoti', 'palazzo'
@@ -298,66 +202,67 @@ def detect_clothing_type(category):
  return "🌍 Western Wear"
 
 
  def get_color_emoji(color_name):
- """
- Get emoji for color (visual representation)
- """
- color_emojis = {
- "Red": "🔴",
- "Blue": "🔵",
- "Green": "🟢",
- "Yellow": "🟡",
- "Orange": "🟠",
- "Purple": "🟣",
- "Pink": "🩷",
- "Brown": "🟤",
- "Black": "⚫",
- "White": "⚪",
- "Gray": "⚫",
- "Light Gray": "⚪",
- "Dark Gray": "⚫",
- "Navy": "🔵",
- "Sky Blue": "🔵",
- "Light Green": "🟢",
- "Dark Green": "🟢",
- "Maroon": "🔴",
- "Coral": "🩷",
- "Gold": "🟡",
- "Beige": "🟤",
- "Cream": "⚪",
- "Olive": "🟢",
- "Cyan": "🔵",
- "Magenta": "🟣",
- "Violet": "🟣",
- "Lavender": "🟣",
- "Light Yellow": "🟡",
- "Dark Brown": "🟤",
- "Multicolor": "🌈"
- }
-
- return color_emojis.get(color_name, "🎨")
 
 
  # ======================
- # Prediction Function
  # ======================
  def predict_fashion(image, custom_categories=None):
  """
- Classify fashion item + detect dominant color (TEXT OUTPUT)
  """
  if image is None:
  return "⚠️ Please upload an image first!", {}
 
  try:
- # Step 1: Categories
  if custom_categories and custom_categories.strip():
  categories = [cat.strip() for cat in custom_categories.split(",")]
  else:
  categories = FASHION_CATEGORIES
 
- # Step 2: Fashion Classification
  print("[INFO] 🔍 Classifying fashion item...")
- inputs = processor(
  text=categories,
  images=image,
  return_tensors="pt",
@@ -365,7 +270,7 @@ def predict_fashion(image, custom_categories=None):
  )
 
  with torch.no_grad():
- outputs = model(**inputs)
 
  logits_per_image = outputs.logits_per_image
  probs = logits_per_image.softmax(dim=1)[0]
@@ -374,14 +279,13 @@
  top_category = categories[top_prob_idx]
  top_confidence = probs[top_prob_idx].item()
 
- # Step 3: Dominant Color Detection
- print("[INFO] 🎨 Detecting dominant colors...")
- dominant_colors, rgb_values = get_dominant_color(image, n_colors=3)
 
- # Step 4: Type Detection
  clothing_type = detect_clothing_type(top_category)
 
- # Step 5: Format result with TEXT COLORS
  result = f"""
  ### 🎯 Fashion Item Detected
 
@@ -391,48 +295,52 @@
 
  ---
 
- ### 🎨 Detected Colors (Text Format)
 
  """
 
- # ✅ DISPLAY COLORS AS TEXT
- if dominant_colors and dominant_colors[0] != "Unable to detect":
  # Primary Color
- primary_color = dominant_colors[0]
  primary_emoji = get_color_emoji(primary_color)
- primary_rgb = rgb_values[0] if rgb_values else [0, 0, 0]
 
- result += f"**Primary Color:** {primary_emoji} **{primary_color}** ✨\n"
- result += f"*RGB Values: ({primary_rgb[0]}, {primary_rgb[1]}, {primary_rgb[2]})*\n\n"
 
  # Secondary Colors
- if len(dominant_colors) > 1:
  result += "**Secondary Colors:**\n"
- for i, (color, rgb) in enumerate(zip(dominant_colors[1:], rgb_values[1:]), 1):
  emoji = get_color_emoji(color)
- result += f" {i}. {emoji} **{color}** - RGB({rgb[0]}, {rgb[1]}, {rgb[2]})\n"
  result += "\n"
 
  # Color Summary
- result += f"**Color Summary:** {', '.join(dominant_colors)} 🌈\n"
  else:
- result += "⚠️ Unable to detect colors from image\n"
 
  result += f"""
  ---
 
  ### 📊 Detection Details
 
- ✅ Background removed automatically
- 🎨 K-means clustering used for color detection
- 🔍 Classification confidence: **{top_confidence:.1%}**
- 📸 Analyzed against **{len(categories)}** categories
 
  ---
 
  ### 💡 Styling Suggestions
  """
 
  if "Indian" in clothing_type:
  result += """
  - Perfect for traditional occasions 🪔
@@ -446,34 +354,36 @@ def predict_fashion(image, custom_categories=None):
  - Suitable for casual/formal settings
  """
 
- # Color-based suggestions
- if dominant_colors[0] != "Unable to detect":
- primary = dominant_colors[0].lower()
 
- result += f"\n**Color Styling Tips for {dominant_colors[0]}:**\n"
 
- if "black" in primary:
- result += "- Classic and versatile ⚫\n- Pairs well with any color\n- Perfect for formal events\n"
- elif "white" in primary:
- result += "- Clean and fresh look ⚪\n- Great for summer\n- Easy to accessorize\n"
- elif "red" in primary:
- result += "- Bold and confident 🔴\n- Statement piece\n- Pair with neutral colors\n"
- elif "blue" in primary:
- result += "- Cool and calming 🔵\n- Professional look\n- Versatile for day/night\n"
- elif "green" in primary:
- result += "- Natural and refreshing 🟢\n- Great for outdoor events\n- Pairs with earth tones\n"
- elif "yellow" in primary or "gold" in primary:
- result += "- Bright and cheerful 🟡\n- Perfect for festivities\n- Eye-catching choice\n"
- elif "pink" in primary:
- result += "- Soft and feminine 🩷\n- Romantic vibe\n- Great for parties\n"
- elif "purple" in primary or "violet" in primary:
- result += "- Royal and elegant 🟣\n- Sophisticated choice\n- Unique statement\n"
- elif "brown" in primary or "beige" in primary:
- result += "- Earthy and warm 🟤\n- Natural aesthetic\n- Timeless appeal\n"
- elif "gray" in primary:
- result += "- Neutral and modern ⚫\n- Professional look\n- Easy to style\n"
-
- # Top predictions
  top_probs, top_indices = torch.topk(probs, k=min(5, len(categories)))
  top_predictions = {}
  for prob, idx in zip(top_probs, top_indices):
@@ -491,27 +401,25 @@
  # ======================
  # Gradio Interface
  # ======================
- with gr.Blocks(title="Fashion-CLIP + Color Detection", theme=gr.themes.Soft()) as demo:
 
  gr.Markdown("""
- # 👗 AI Fashion Classifier with Smart Color Detection
- ### Background Removal + K-Means Clustering for Accurate Colors (Text Output)
  """)
 
- status = "🟢 High Quality (rembg)" if REMBG_AVAILABLE else "🟡 Standard (GrabCut)"
-
  gr.Markdown(f"""
- **Model:** Fashion-CLIP
- **Color Detection:** K-Means Clustering
- **Background Removal:** {status}
- **Output Format:** Text-based color names with RGB values
-
- ### ✨ How it works:
- 1. 🖼️ **Upload** your fashion item image
- 2. 🔄 **Background** is automatically removed
- 3. 🎨 **K-means** finds the dominant colors from clothing only
- 4. 👗 **Fashion-CLIP** classifies the item type
- 5. 📝 **Colors displayed as text** with emojis and RGB values
  """)
 
  with gr.Row():
@@ -520,81 +428,93 @@ with gr.Blocks(title="Fashion-CLIP + Color Detection", theme=gr.themes.Soft()) a
 
  custom_categories = gr.Textbox(
  label="🏷️ Custom Categories (Optional)",
- placeholder="red saree, blue kurta, black jeans, white shirt",
  info="Comma-separated. Leave empty for 50+ default categories."
  )
 
- predict_btn = gr.Button("🔍 Analyze Fashion Item", variant="primary", size="lg")
 
  gr.Markdown("""
  **💡 Tips:**
- - Clear, well-lit photos work best
- - Single clothing item preferred
- - Background will be auto-removed
- - Colors shown as text with RGB values
 
- **📦 For Best Results:**
- ```bash
- pip install rembg scikit-learn
- ```
- """)
 
  with gr.Column():
- output_text = gr.Markdown(label="📋 Results")
- output_label = gr.Label(label="📊 Top 5 Predictions", num_top_classes=5)
 
- # Event
  predict_btn.click(
  fn=predict_fashion,
  inputs=[input_image, custom_categories],
  outputs=[output_text, output_label]
  )
 
- gr.Markdown("""
  ---
- ### 📝 Example Queries
 
- | Type | Example Categories |
- |------|-------------------|
- | **Indian** | `red saree, green lehenga, blue kurta` |
- | **Western** | `black jeans, white shirt, blue dress` |
- | **Formal** | `black suit, white shirt, navy blazer` |
- | **Footwear** | `white sneakers, black boots, brown sandals` |
 
  ---
 
- ### 🎨 Color Output Format:
 
- **Primary Color:** 🔴 **Red** ✨
- *RGB Values: (255, 0, 0)*
 
- **Secondary Colors:**
- 1. ⚪ **White** - RGB(255, 255, 255)
- 2. ⚫ **Black** - RGB(0, 0, 0)
 
- **Color Summary:** Red, White, Black 🌈
 
  ---
 
  **🚀 Powered by:**
  - Fashion-CLIP (patrickjohncyh/fashion-clip)
- - K-Means Clustering for dominant colors
- - rembg/GrabCut for background removal
- - Text-based color output with emojis
- """)
 
  # ======================
  # Launch
  # ======================
  if __name__ == "__main__":
  print("\n" + "="*60)
- print("🚀 FASHION CLASSIFIER WITH SMART COLOR DETECTION (TEXT)")
  print("="*60)
- print(f"✅ Fashion-CLIP Model: Loaded")
- print(f"✅ Categories: {len(FASHION_CATEGORIES)}")
- print(f"✅ Background Removal: {'rembg (High Quality)' if REMBG_AVAILABLE else 'GrabCut (Standard)'}")
- print(f"✅ Color Detection: K-Means Clustering")
- print(f"✅ Output Format: Text with RGB values")
  print("="*60 + "\n")
 
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
+ # app.py - Fashion Classification with PRETRAINED Color Detection
 
  import gradio as gr
+ from transformers import CLIPProcessor, CLIPModel, pipeline
  from PIL import Image
  import torch
  import numpy as np
 
  # ======================
+ # Model Configuration
  # ======================
+ print("[INFO] Loading models...")
 
+ # Fashion Classification Model
+ FASHION_MODEL = "patrickjohncyh/fashion-clip"
+ fashion_model = CLIPModel.from_pretrained(FASHION_MODEL)
+ fashion_processor = CLIPProcessor.from_pretrained(FASHION_MODEL)
+ print("[SUCCESS] ✅ Fashion-CLIP loaded!")
 
+ # Color Detection Model - Using CLIP for color detection
+ color_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ color_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ print("[SUCCESS] ✅ Color Detection Model (CLIP) loaded!")
 
  # ======================
+ # Fashion Categories
  # ======================
  FASHION_CATEGORIES = [
  # Indian Wear
 
  ]
 
  # ======================
+ # Comprehensive Color List for CLIP
  # ======================
+ COLOR_LABELS = [
+ # Basic Colors
+ "red", "blue", "green", "yellow", "orange", "purple", "pink",
+ "brown", "black", "white", "gray", "grey",
+
+ # Reds
+ "dark red", "light red", "crimson", "maroon", "burgundy",
+ "wine red", "cherry red", "scarlet",
+
+ # Pinks
+ "light pink", "hot pink", "coral", "salmon", "rose pink",
+ "baby pink", "magenta", "fuchsia",
+
+ # Oranges
+ "dark orange", "light orange", "peach", "tangerine",
+ "rust orange", "burnt orange",
+
+ # Yellows
+ "light yellow", "dark yellow", "golden yellow", "lemon yellow",
+ "mustard yellow", "cream yellow", "amber",
+
+ # Greens
+ "dark green", "light green", "forest green", "olive green",
+ "mint green", "lime green", "emerald green", "sage green",
+ "teal", "sea green",
+
+ # Blues
+ "dark blue", "light blue", "navy blue", "royal blue",
+ "sky blue", "baby blue", "turquoise", "cyan", "aqua",
+ "indigo", "cobalt blue", "denim blue", "steel blue",
+
+ # Purples
+ "dark purple", "light purple", "violet", "lavender",
+ "plum", "orchid", "mauve", "lilac",
+
+ # Browns
+ "dark brown", "light brown", "chocolate brown", "tan",
+ "beige", "khaki", "caramel", "coffee brown",
+ "taupe", "sand", "bronze",
+
+ # Grays
+ "light gray", "dark gray", "charcoal", "silver",
+ "ash gray", "slate gray", "stone gray",
+
+ # Special
+ "gold", "silver", "copper", "cream", "ivory", "off-white",
+ "wine", "burgundy", "rust", "denim", "multicolor"
+ ]
+
+ # Prepare color prompts for better detection
+ COLOR_PROMPTS = [f"a {color} colored clothing item" for color in COLOR_LABELS]
 
 
  # ======================
+ # Color Detection using CLIP (Pretrained)
  # ======================
+ def detect_color_with_clip(image, top_k=3):
  """
+ Detect color using pretrained CLIP model
 
+ Returns: List of (color_name, confidence_score)
  """
  try:
+ print("[INFO] 🎨 Detecting colors with CLIP model...")
 
+ # Prepare inputs
+ inputs = color_processor(
+ text=COLOR_PROMPTS,
+ images=image,
+ return_tensors="pt",
+ padding=True
+ )
 
+ # Get predictions
+ with torch.no_grad():
+ outputs = color_model(**inputs)
 
+ # Calculate probabilities
+ logits_per_image = outputs.logits_per_image
+ probs = logits_per_image.softmax(dim=1)[0]
 
+ # Get top K colors
+ top_probs, top_indices = torch.topk(probs, k=top_k)
 
+ detected_colors = []
+ for prob, idx in zip(top_probs, top_indices):
+ color_name = COLOR_LABELS[idx.item()]
+ confidence = prob.item()
 
+ # Only include if confidence > 5%
+ if confidence > 0.05:
+ detected_colors.append((color_name, confidence))
+ print(f"[INFO] - {color_name}: {confidence:.1%}")
 
+ if not detected_colors:
+ return [("unknown", 0.0)]
 
+ return detected_colors
 
  except Exception as e:
  print(f"[ERROR] Color detection failed: {e}")
  import traceback
  traceback.print_exc()
+ return [("detection failed", 0.0)]
 
 
+ # ======================
+ # Alternative: Using Image Classification Pipeline
+ # ======================
+ def detect_color_with_pipeline(image):
  """
+ Alternative: Using HuggingFace image classification pipeline
  """
+ try:
+ # Load a pretrained color classification model
+ # You can replace this with a specific color detection model if available
+ classifier = pipeline("image-classification", model="google/vit-base-patch16-224")
+
+ results = classifier(image)
+
+ # Filter for color-related predictions
+ color_keywords = ['red', 'blue', 'green', 'yellow', 'orange', 'purple',
+ 'pink', 'brown', 'black', 'white', 'gray', 'grey']
+
+ color_results = []
+ for result in results:
+ label_lower = result['label'].lower()
+ for color in color_keywords:
+ if color in label_lower:
+ color_results.append((color, result['score']))
+ break
+
+ return color_results if color_results else [("multicolor", 0.5)]
+
+ except Exception as e:
+ print(f"[ERROR] Pipeline color detection failed: {e}")
+ return [("unknown", 0.0)]
 
 
+ # ======================
+ # Clothing Type Detection
+ # ======================
  def detect_clothing_type(category):
+ """Detect if Indian or Western wear"""
  indian_wear = [
  'saree', 'kurta', 'salwar', 'lehenga', 'sherwani',
  'churidar', 'anarkali', 'kurti', 'dhoti', 'palazzo'
 
  return "🌍 Western Wear"
 
 
+ # ======================
+ # Color Emoji Mapping
+ # ======================
  def get_color_emoji(color_name):
+ """Get emoji for color"""
+ color_lower = color_name.lower()
+
+ if 'red' in color_lower or 'crimson' in color_lower or 'scarlet' in color_lower:
+ return "🔴"
+ elif 'pink' in color_lower or 'rose' in color_lower or 'coral' in color_lower:
+ return "🩷"
+ elif 'orange' in color_lower or 'peach' in color_lower or 'rust' in color_lower:
+ return "🟠"
+ elif 'yellow' in color_lower or 'gold' in color_lower or 'amber' in color_lower:
+ return "🟡"
+ elif 'green' in color_lower or 'olive' in color_lower or 'lime' in color_lower or 'emerald' in color_lower:
+ return "🟢"
+ elif 'blue' in color_lower or 'navy' in color_lower or 'cyan' in color_lower or 'aqua' in color_lower or 'denim' in color_lower:
+ return "🔵"
+ elif 'purple' in color_lower or 'violet' in color_lower or 'lavender' in color_lower or 'plum' in color_lower:
+ return "🟣"
+ elif 'brown' in color_lower or 'tan' in color_lower or 'beige' in color_lower or 'khaki' in color_lower:
+ return "🟤"
+ elif 'black' in color_lower or 'dark' in color_lower or 'charcoal' in color_lower:
+ return "⚫"
+ elif 'white' in color_lower or 'cream' in color_lower or 'ivory' in color_lower:
+ return "⚪"
+ elif 'gray' in color_lower or 'grey' in color_lower or 'silver' in color_lower:
+ return "⚪"
+ else:
+ return "🎨"
+
+
+ # ======================
+ # Format Color Name
+ # ======================
+ def format_color_name(color):
+ """Format color name to title case"""
+ return color.replace('_', ' ').title()
 
 
  # ======================
+ # Main Prediction Function
  # ======================
  def predict_fashion(image, custom_categories=None):
  """
+ Classify fashion item + detect color using PRETRAINED models
  """
  if image is None:
  return "⚠️ Please upload an image first!", {}
 
  try:
+ # Step 1: Prepare Categories
  if custom_categories and custom_categories.strip():
  categories = [cat.strip() for cat in custom_categories.split(",")]
  else:
  categories = FASHION_CATEGORIES
 
+ # Step 2: Fashion Item Classification
  print("[INFO] 🔍 Classifying fashion item...")
+ inputs = fashion_processor(
  text=categories,
  images=image,
  return_tensors="pt",
 
  )
 
  with torch.no_grad():
+ outputs = fashion_model(**inputs)
 
  logits_per_image = outputs.logits_per_image
  probs = logits_per_image.softmax(dim=1)[0]
 
  top_category = categories[top_prob_idx]
  top_confidence = probs[top_prob_idx].item()
 
+ # Step 3: Color Detection with CLIP (Pretrained)
+ detected_colors = detect_color_with_clip(image, top_k=3)
 
+ # Step 4: Clothing Type
  clothing_type = detect_clothing_type(top_category)
 
+ # Step 5: Format Results
  result = f"""
  ### 🎯 Fashion Item Detected
 
  ---
 
+ ### 🎨 Color Detection (Pretrained CLIP Model)
 
  """
 
+ # Display detected colors
+ if detected_colors and detected_colors[0][0] not in ["unknown", "detection failed"]:
  # Primary Color
+ primary_color, primary_conf = detected_colors[0]
+ primary_formatted = format_color_name(primary_color)
  primary_emoji = get_color_emoji(primary_color)
 
+ result += f"**Primary Color:** {primary_emoji} **{primary_formatted}** ✨\n"
+ result += f"*Confidence: {primary_conf:.1%}*\n\n"
 
  # Secondary Colors
+ if len(detected_colors) > 1:
  result += "**Secondary Colors:**\n"
+ for i, (color, conf) in enumerate(detected_colors[1:], 1):
+ formatted = format_color_name(color)
  emoji = get_color_emoji(color)
+ result += f" {i}. {emoji} **{formatted}** ({conf:.1%})\n"
  result += "\n"
 
  # Color Summary
+ color_names = [format_color_name(c[0]) for c in detected_colors]
+ result += f"**Color Summary:** {', '.join(color_names)} 🌈\n"
  else:
+ result += f"⚠️ Color detection: {detected_colors[0][0]}\n"
 
  result += f"""
  ---
 
  ### 📊 Detection Details
 
+ ✅ **Fashion Model:** Fashion-CLIP (pretrained)
+ 🎨 **Color Model:** CLIP Vision Transformer (pretrained)
+ 🔍 **Color Database:** {len(COLOR_LABELS)} color categories
+ 📊 **Classification Confidence:** {top_confidence:.1%}
+ 🧠 **Method:** Zero-shot learning (no training needed)
 
  ---
 
  ### 💡 Styling Suggestions
  """
 
+ # Clothing type suggestions
  if "Indian" in clothing_type:
  result += """
  - Perfect for traditional occasions 🪔
 
  - Suitable for casual/formal settings
  """
 
+ # Color-specific styling tips
+ if detected_colors and detected_colors[0][0] not in ["unknown", "detection failed"]:
+ primary_color = detected_colors[0][0].lower()
 
+ result += f"\n**Styling Tips for {format_color_name(detected_colors[0][0])}:**\n"
 
+ if 'black' in primary_color or 'dark' in primary_color:
+ result += "- Timeless and elegant ⚫\n- Pairs with everything\n- Perfect for formal occasions\n"
+ elif 'white' in primary_color or 'cream' in primary_color or 'ivory' in primary_color:
+ result += "- Fresh and clean ⚪\n- Summer favorite\n- Easy to accessorize\n"
+ elif 'gray' in primary_color or 'grey' in primary_color or 'silver' in primary_color:
+ result += "- Sophisticated neutral ⚫\n- Professional choice\n- Modern aesthetic\n"
+ elif 'red' in primary_color or 'maroon' in primary_color or 'crimson' in primary_color:
+ result += "- Bold statement 🔴\n- Confidence booster\n- Pair with neutrals\n"
+ elif 'blue' in primary_color or 'navy' in primary_color or 'denim' in primary_color:
+ result += "- Classic choice 🔵\n- Versatile wear\n- Calming effect\n"
+ elif 'green' in primary_color or 'olive' in primary_color:
+ result += "- Natural vibe 🟢\n- Fresh look\n- Great for outdoors\n"
+ elif 'yellow' in primary_color or 'gold' in primary_color:
+ result += "- Cheerful color 🟡\n- Festive choice\n- Eye-catching\n"
+ elif 'pink' in primary_color or 'coral' in primary_color:
+ result += "- Soft and feminine 🩷\n- Romantic appeal\n- Party ready\n"
+ elif 'purple' in primary_color or 'violet' in primary_color:
+ result += "- Royal elegance 🟣\n- Unique choice\n- Sophisticated\n"
+ elif 'brown' in primary_color or 'tan' in primary_color or 'beige' in primary_color:
+ result += "- Earthy warmth 🟤\n- Natural look\n- Timeless style\n"
+ elif 'orange' in primary_color or 'peach' in primary_color:
+ result += "- Vibrant energy 🟠\n- Playful choice\n- Summer perfect\n"
+
+ # Top fashion predictions
  top_probs, top_indices = torch.topk(probs, k=min(5, len(categories)))
  top_predictions = {}
  for prob, idx in zip(top_probs, top_indices):
 
  # ======================
  # Gradio Interface
  # ======================
+ with gr.Blocks(title="Fashion AI with Pretrained Color Detection", theme=gr.themes.Soft()) as demo:
 
  gr.Markdown("""
+ # 👗 AI Fashion Classifier with Pretrained Color Detection
+ ### Using CLIP Vision Transformer for Zero-Shot Color Recognition
  """)
 
  gr.Markdown(f"""
+ **Fashion Model:** Fashion-CLIP (pretrained)
+ **Color Model:** OpenAI CLIP ViT-B/32 (pretrained)
+ **Color Categories:** {len(COLOR_LABELS)} colors
+ **Method:** Zero-shot learning (no dataset training needed)
+
+ ### ✨ Why Pretrained Models?
+ 1. 🎯 **Highly Accurate** - Trained on millions of images
+ 2. ⚡ **Fast** - No preprocessing needed
+ 3. 🧠 **Smart** - Understands context and variations
+ 4. 🔄 **Generalizable** - Works on any clothing type
+ 5. 📊 **Reliable** - Consistent results
  """)
 
  with gr.Row():
 
  custom_categories = gr.Textbox(
  label="🏷️ Custom Categories (Optional)",
+ placeholder="gray shorts, blue jeans, red kurta, white shirt",
  info="Comma-separated. Leave empty for 50+ default categories."
  )
 
+ predict_btn = gr.Button("🔍 Analyze with AI Models", variant="primary", size="lg")
 
  gr.Markdown("""
  **💡 Tips:**
+ - Clear photos work best
+ - Good lighting recommended
+ - Single item preferred
 
+ **🎨 Supported Colors ({} types):**
+ - Basic: Red, Blue, Green, Yellow, Orange, Purple, Pink, Brown, Black, White, Gray
+ - Shades: Dark/Light variations
+ - Specific: Navy, Maroon, Teal, Lavender, Beige, etc.
+
+ **⚡ No Installation Needed:**
+ All models are pretrained and ready to use!
+ """.format(len(COLOR_LABELS)))
 
  with gr.Column():
+ output_text = gr.Markdown(label="📋 AI Analysis Results")
+ output_label = gr.Label(label="📊 Top 5 Item Predictions", num_top_classes=5)
 
+ # Event Handler
  predict_btn.click(
  fn=predict_fashion,
  inputs=[input_image, custom_categories],
  outputs=[output_text, output_label]
  )
 
+ gr.Markdown(f"""
  ---
+ ### 📝 Example Test Cases
 
+ | Item | Expected Colors |
+ |------|----------------|
+ | **Gray Shorts** | Gray, Light Gray, Dark Gray, Charcoal |
+ | **Denim Jeans** | Denim Blue, Navy Blue, Dark Blue |
+ | **Red Saree** | Red, Crimson, Dark Red |
+ | **White Shirt** | White, Off-White, Cream |
+ | **Black Kurta** | Black, Dark Gray, Charcoal |
+ | **Beige Dress** | Beige, Tan, Light Brown, Cream |
 
  ---
 
+ ### 🎨 Color Detection Technology
 
+ **Model:** OpenAI CLIP (Contrastive Language-Image Pre-training)
 
+ **How it works:**
+ 1. Image is processed through Vision Transformer
+ 2. Compared with {len(COLOR_LABELS)} color text descriptions
+ 3. Returns best matching colors with confidence scores
+ 4. No background removal needed
+ 5. Context-aware (understands "denim blue" vs "sky blue")
 
+ **Advantages over traditional methods:**
+ - ✅ Pretrained on 400M+ image-text pairs
+ - ✅ Understands color context (e.g., "denim blue", "burgundy red")
+ - ✅ No manual threshold tuning needed
+ - ✅ Works on complex patterns and textures
+ - ✅ Handles shadows and lighting variations
 
  ---
 
  **🚀 Powered by:**
  - Fashion-CLIP (patrickjohncyh/fashion-clip)
+ - OpenAI CLIP ViT-B/32
+ - HuggingFace Transformers
+ - Zero-shot learning (no training required)
+ """.format(len(COLOR_LABELS)))
 
  # ======================
  # Launch
  # ======================
  if __name__ == "__main__":
  print("\n" + "="*60)
+ print("🚀 FASHION AI WITH PRETRAINED COLOR DETECTION")
  print("="*60)
+ print(f"✅ Fashion Model: Fashion-CLIP (loaded)")
+ print(f"✅ Color Model: CLIP ViT-B/32 (loaded)")
+ print(f"✅ Fashion Categories: {len(FASHION_CATEGORIES)}")
+ print(f"✅ Color Categories: {len(COLOR_LABELS)}")
+ print(f"✅ Method: Zero-shot learning")
+ print(f"✅ Background Removal: Not needed (AI handles it)")
  print("="*60 + "\n")
 
  demo.launch(server_name="0.0.0.0", server_port=7860)
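
The `detect_color_with_clip` function added in this commit scores a fixed list of color prompts against the uploaded image with a stock CLIP checkpoint. A minimal standalone sketch of that zero-shot lookup, assuming the same `openai/clip-vit-base-patch32` model and prompt template as the diff (the shortened `COLOR_LABELS` list and the `example.jpg` path below are illustrative, not part of the commit):

```python
# Zero-shot color lookup with CLIP, mirroring detect_color_with_clip from the diff.
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor

# Illustrative subset of the commit's COLOR_LABELS list.
COLOR_LABELS = ["red", "navy blue", "denim blue", "charcoal", "beige", "white"]
COLOR_PROMPTS = [f"a {color} colored clothing item" for color in COLOR_LABELS]

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def top_colors(image: Image.Image, k: int = 3):
    # Encode the image once against every color prompt and keep the k best matches.
    inputs = processor(text=COLOR_PROMPTS, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        probs = model(**inputs).logits_per_image.softmax(dim=1)[0]
    top = torch.topk(probs, k=min(k, len(COLOR_LABELS)))
    return [(COLOR_LABELS[i], p.item()) for p, i in zip(top.values, top.indices)]

if __name__ == "__main__":
    # "example.jpg" is a hypothetical local photo of a clothing item.
    print(top_colors(Image.open("example.jpg").convert("RGB")))
```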