Spaces: Runtime error
import gradio as gr
import yt_dlp
import os
import time
import torch
import transformers
import clip
import numpy as np
import cv2
import random
from PIL import Image
from multilingual_clip import pt_multilingual_clip
class SearchVideo:
    def __init__(
        self,
        clip_model,
        text_model,
        tokenizer,
        compose,
    ) -> None:
        """
        clip_model: loaded CLIP model used for image embeddings
        text_model: multilingual text encoder used for the text embedding
        tokenizer: tokenizer matching text_model
        compose: CLIP preprocessing transform applied to each frame
        """
        self.text_model = text_model
        self.tokenizer = tokenizer
        self.clip_model = clip_model
        self.compose = compose
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
    def __call__(self, video: str, text: str) -> tuple:
        torch.cuda.empty_cache()
        # frames are sampled at 1 fps, so the index of the best-matching
        # frame is (approximately) its timestamp in seconds
        frames = self.video2frames_ffmpeg(video)
        img_embs = self.get_img_embs(frames)
        txt_emb = self.get_txt_embs(text)
        # scaling the text embedding is harmless: compare_embeddings
        # re-normalises both sides before taking the dot product
        txt_emb = txt_emb * len(frames)
        logits_per_image = self.compare_embeddings(img_embs, txt_emb)
        logits_per_image = [logit.numpy()[0] for logit in logits_per_image]
        ind = np.argmax(logits_per_image)
        seg_path = self.extract_seg(video, ind)
        return ind, seg_path, frames[ind]
    def extract_seg(self, video: str, start: int):
        # back up 5 seconds when possible so the clip starts slightly before
        # the matched frame, but never seek to a negative timestamp
        start = start - 5 if start > 5 else start
        start = time.strftime('%H:%M:%S', time.gmtime(start))
        # stream-copy a short segment around the matched timestamp
        cmd = f'ffmpeg -ss {start} -i "{video}" -t 00:00:02 -vcodec copy -acodec copy -y segment_{start}.mp4'
        os.system(cmd)
        return f'segment_{start}.mp4'
    def video2frames_ffmpeg(self, video: str) -> list:
        frames_dir = 'frames'
        if not os.path.exists(frames_dir):
            os.makedirs(frames_dir)
        # extract one frame per second of video
        os.system(f'ffmpeg -i "{video}" -r 1 {frames_dir}/output-%04d.jpg')
        # convert() forces PIL to load the pixel data before the
        # temporary frame files are deleted below
        images = [Image.open(f'{frames_dir}/{f}').convert('RGB')
                  for f in sorted(os.listdir(frames_dir))]
        os.system(f'rm -rf {frames_dir}')
        return images
    def video2frames(self, video: str) -> list:
        # OpenCV-based alternative to video2frames_ffmpeg (currently unused);
        # samples roughly one frame per second assuming ~24 fps
        cap = cv2.VideoCapture(video)
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        images = []
        frames_sec = set(range(0, num_frames, 24))
        has_frames, image = cap.read()
        frame_count = 0
        while has_frames:
            has_frames, image = cap.read()
            frame_count += 1
            if has_frames and frame_count in frames_sec:
                # OpenCV returns BGR; convert to RGB before building a PIL image
                image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
                images.append(image)
        cap.release()
        return images
    def get_img_embs(self, img_list: list) -> torch.Tensor:
        """
        Takes a list of PIL images and computes CLIP image embeddings
        with the model given by clip_model.
        """
        img_input = torch.stack([self.compose(img).to(self.device)
                                 for img in img_list])
        with torch.no_grad():
            image_embs = self.clip_model.encode_image(img_input).float().cpu()
        return image_embs
    def get_txt_embs(self, text: str) -> torch.Tensor:
        """Computes the multilingual CLIP embedding for the text query."""
        with torch.no_grad():
            return self.text_model(text, self.tokenizer)
    def compare_embeddings(self, img_embs, txt_embs):
        # normalized features
        image_features = img_embs / img_embs.norm(dim=-1, keepdim=True)
        text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)
        # cosine similarity as logits, one score per frame
        logits_per_image = []
        for image_feature in image_features:
            logits_per_image.append(image_feature @ text_features.t())
        return logits_per_image
def download_yt_video(url):
    ydl_opts = {
        'quiet': True,
        # save as <video id>.<ext> so the filename can be reconstructed below
        "outtmpl": "%(id)s.%(ext)s",
        # prefer 360p mp4 video+audio, fall back to the worst available stream
        'format': 'bv*[height<=360][ext=mp4]+ba/b[height<=360] / wv*+ba/w'
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([url])
    # assumes a standard https://www.youtube.com/watch?v=<id> URL
    return url.split('/')[-1].replace('watch?v=', '') + '.mp4'
clip_model = 'ViT-B/32'
text_model = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
# clip.load returns the model together with its image preprocessing transform
clip_model, compose = clip.load(clip_model)
tokenizer = transformers.AutoTokenizer.from_pretrained(text_model)
text_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(text_model)
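# A minimal standalone sketch of how the pieces above fit together without the
# Gradio UI (assumes a local file "my_video.mp4" exists; the path is hypothetical):
#
#   searcher = SearchVideo(clip_model=clip_model, text_model=text_model,
#                          tokenizer=tokenizer, compose=compose)
#   ind, seg_path, frame = searcher("my_video.mp4", "a dog")
#   print(f'best match at second {ind}, 2-sec clip written to {seg_path}')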
def search_video(video_url, text, video=None):
    search = SearchVideo(
        clip_model=clip_model,
        text_model=text_model,
        tokenizer=tokenizer,
        compose=compose
    )
    # an uploaded video takes precedence over the URL field
    if video is not None:
        video_url = None
    if video_url:
        video = download_yt_video(video_url)
    ind, seg_path, img = search(video, text)
    # frames were sampled at 1 fps, so the frame index is the timestamp in seconds
    start = time.strftime('%H:%M:%S', time.gmtime(ind))
    return f'"{text}" found at {start}', seg_path
title = 'Search inside a video'
description = '''Enter a search query and a video URL (or upload your own video) and get a 2-sec fragment of the video that is visually closest to your query.'''
examples = [["https://www.youtube.com/watch?v=M93w3TjzVUE", "A dog", None]]
iface = gr.Interface(
    search_video,
    inputs=[
        gr.Textbox(value="https://www.youtube.com/watch?v=M93w3TjzVUE", label='Video URL'),
        gr.Textbox(value="a dog", label='Text query'),
        gr.Video(),
    ],
    outputs=[gr.Textbox(label="Output"), gr.Video(label="Video segment")],
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples
)
if __name__ == "__main__":
    iface.launch(show_error=True)
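# To run the app locally (a sketch, assuming this file is saved as app.py, the
# dependencies above are installed, and ffmpeg is available on PATH):
# `python app.py` launches the Gradio interface, served by default at
# http://127.0.0.1:7860.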