from PIL import ImageDraw, ImageFont, Image
import cv2
import torch
import numpy as np
import uuid
import spaces
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

# === Load model (loaded only once when the Space starts) ===
image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd").to("cuda" if torch.cuda.is_available() else "cpu")

SUBSAMPLE = 2  # reduce FPS to save resources


class StreamObjectDetection:
    @staticmethod
    def draw_bounding_boxes(image, boxes, model, conf_threshold):
        draw = ImageDraw.Draw(image)
        font = ImageFont.load_default()

        for score, label, box in zip(boxes["scores"], boxes["labels"], boxes["boxes"]):
            if score < conf_threshold:
                continue
            x0, y0, x1, y1 = box
            label_text = f"{model.config.id2label[label.item()]}: {score:.2f}"
            draw.rectangle([x0, y0, x1, y1], outline="red", width=3)
            draw.text((x0 + 3, y0 + 3), label_text, fill="white", font=font)

        return image

    @staticmethod
    @spaces.GPU  # Use GPU if available (ZeroGPU, GPU cluster, etc.)
    def stream_object_detection(video, conf_threshold=0.3):
        cap = cv2.VideoCapture(video)
        video_codec = cv2.VideoWriter_fourcc(*"mp4v")
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 24
        desired_fps = max(1, fps // SUBSAMPLE)
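        # Output at half resolution to match the 0.5x resize applied to each frame below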
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) // 2

        iterating, frame = cap.read()
        n_frames = 0
        output_video_name = f"output_{uuid.uuid4()}.mp4"
        output_video = cv2.VideoWriter(output_video_name, video_codec, desired_fps, (width, height))
        batch = []

        while iterating:
            frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            if n_frames % SUBSAMPLE == 0:
                batch.append(frame)

            # Process a batch once every 2 seconds of video
            if len(batch) == 2 * desired_fps:
                inputs = image_processor(images=batch, return_tensors="pt").to(model.device)

                with torch.no_grad():
                    outputs = model(**inputs)

                boxes = image_processor.post_process_object_detection(
                    outputs,
                    target_sizes=torch.tensor([(height, width)] * len(batch)).to(model.device),
                    threshold=conf_threshold,
                )

                for img, box in zip(batch, boxes):
                    pil_image = StreamObjectDetection.draw_bounding_boxes(Image.fromarray(img), box, model, conf_threshold)
                    frame_bgr = np.array(pil_image)[:, :, ::-1]
                    output_video.write(frame_bgr)

                batch = []
                output_video.release()
                yield output_video_name  # Stream this processed chunk to Gradio
                output_video_name = f"output_{uuid.uuid4()}.mp4"
                output_video = cv2.VideoWriter(output_video_name, video_codec, desired_fps, (width, height))

            iterating, frame = cap.read()
            n_frames += 1

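        # Note: any leftover frames in `batch` (less than 2 seconds' worth at the end
        # of the video) are dropped rather than written to a final chunk.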
        cap.release()
        output_video.release()
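

# --- Usage sketch (not part of the original file) ---
# A minimal, assumed wiring of the generator above into a Gradio app. It relies
# on gr.Video(streaming=True), which plays each yielded .mp4 chunk as it arrives;
# the component labels, layout, and launch call below are illustrative assumptions.
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("## RT-DETR video object detection (streamed output)")
    with gr.Row():
        input_video = gr.Video(label="Input video")
        # streaming=True lets playback start as soon as the first chunk is yielded
        output_video = gr.Video(label="Processed video", streaming=True, autoplay=True)
    conf_threshold = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Confidence threshold")

    input_video.upload(
        fn=StreamObjectDetection.stream_object_detection,
        inputs=[input_video, conf_threshold],
        outputs=[output_video],
    )

if __name__ == "__main__":
    demo.launch()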