import gradio as gr
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from PIL import Image
import torch
import warnings
import os
import pandas as pd
import logging
import subprocess
import spaces
import openai
import base64
# Run `git lfs pull` to fetch the model weights tracked by Git LFS
result = subprocess.run(['git', 'lfs', 'pull'], capture_output=True, text=True)

# Print the result of the command (optional)
if result.returncode == 0:
    print("LFS files downloaded successfully.")
else:
    print(f"Error occurred: {result.stderr}")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger()
warnings.filterwarnings('ignore')

MAX_PATCHES = 512

# Load the models and processor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the fine-tuned ko-deplot checkpoint
ko_deplot_model_path = './deplot_model_ver_24.11.21_korean_only(exclude NUUA)_epoch1.bin'
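# The .bin file is a state_dict fine-tuned from the 'nuua/ko-deplot' base model
# (it is loaded on top of that base in load_model1 below) and is fetched by the
# `git lfs pull` call above.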
# Load first model: ko-deplot
def load_model1():
    processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
    model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
    model1.load_state_dict(torch.load(ko_deplot_model_path, map_location="cpu"))
    model1.to(device)  # use the shared device so the app also runs on CPU-only machines
    return processor1, model1

processor1, model1 = load_model1()
# Function to format output
def format_output(prediction):
    return prediction.replace('<0x0A>', '\n')
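# Hypothetical example: ko-deplot emits a linearized table in which '<0x0A>'
# stands for a newline, e.g. "TITLE | 월별 판매량 <0x0A> 1월 | 120 <0x0A> 2월 | 135",
# which format_output() turns into one line per table row.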
# First model prediction: ko-deplot
def predict_model1(image):
    images = [image]
    inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}  # move tensors to the model's device
    model1.eval()
    with torch.no_grad():
        predictions = model1.generate(**inputs, max_new_tokens=4096)
    outputs = [processor1.decode(pred, skip_special_tokens=True) for pred in predictions]
    formatted_output = format_output(outputs[0])
    return formatted_output
# Read the OpenAI API key from the environment instead of hardcoding a secret
openai.api_key = os.getenv("OPENAI_API_KEY")
# Function to encode the image as base64
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
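# The base64 string is embedded in a data URL below, which is one of the input
# formats the Chat Completions API accepts for images.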
# Second model prediction: gpt-4o-mini
def predict_model2(image):
    # Encode the uploaded image to base64
    image_data = encode_image(image)
    # Prepare the request content
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "please extract chart title and chart data manually and present them as a table. you should only provide title and table without adding any additional comments such as **Chart Title:** ."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_data}"
                        }
                    }
                ]
            }
        ]
    )
    # Return the table data from the response
    return response.choices[0]["message"]["content"]
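# The prompt above asks the model for only a title line followed by a markdown
# table; gpt_convert_to_dataframe() below relies on that shape.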
# Convert the text generated by ko-deplot into a pandas DataFrame
def ko_deplot_convert_to_dataframe(label_table_str):
    lines = label_table_str.strip().split("\n")
    data = []
    title = lines[0].split(" | ")[1]  # first line looks like "TITLE | <chart title>"
    if len(lines[1].split("|")) == len(lines[2].split("|")):
        # Header row and data rows have the same number of columns
        headers = lines[1].split(" | ")
        for line in lines[2:]:
            data.append(line.split(" | "))
        df = pd.DataFrame(data, columns=headers)
        return df, title
    else:
        # Data rows carry an extra leading category column, so pad the legend row to match
        legend_row = lines[1].split("|")
        legend_row.insert(0, " ")
        for line in lines[2:]:
            data.append(line.split(" | "))
        df = pd.DataFrame(data, columns=legend_row)
        return df, title
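# Hypothetical example: the input
#   "TITLE | 분기별 매출\n분기 | 매출\n1분기 | 100\n2분기 | 120"
# takes the first branch and yields a two-row DataFrame titled "분기별 매출".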
# Convert the markdown table generated by gpt into a pandas DataFrame
def gpt_convert_to_dataframe(table_text):
    try:
        # Split the text into lines
        lines = table_text.strip().split("\n")
        title = lines[0]
        # Remove the line after the title and the markdown separator row.
        # Note that the list shifts after the first pop, so pop(2) removes
        # what was originally the fourth line.
        lines.pop(1)
        lines.pop(2)
        # Process the remaining lines to create the DataFrame:
        # split by | and drop the empty first/last items of each row
        data = [line.split("|")[1:-1] for line in lines[1:]]
        dataframe = pd.DataFrame(data[1:], columns=[col.strip() for col in data[0]])  # first row as headers
        return dataframe, title
    except Exception as e:
        return f"Error converting table to DataFrame: {e}"
def real_time_check(image_file):
    image = Image.open(image_file)
    ko_deplot_generated_txt = predict_model1(image)
    # Drop the last line of the ko-deplot output, which may be truncated by the token limit
    parts = ko_deplot_generated_txt.split("\n")
    del parts[-1]
    ko_deplot_generated_txt = "\n".join(parts)
    gpt_generated_txt = predict_model2(image_file)
    try:
        ko_deplot_generated_df, ko_deplot_generated_title = ko_deplot_convert_to_dataframe(ko_deplot_generated_txt)
        gpt_generated_df, gpt_generated_title = gpt_convert_to_dataframe(gpt_generated_txt)
        return gr.DataFrame(ko_deplot_generated_df, label=ko_deplot_generated_title), gr.DataFrame(gpt_generated_df, label=gpt_generated_title), None, None, 0
    except Exception:
        # Fall back to the raw generated text when conversion fails
        return None, None, ko_deplot_generated_txt, gpt_generated_txt, 1
# Flag recording whether an exception occurred: 1 means the generated text
# could not be converted to a pandas DataFrame.
flag = 0
def inference(image_uploader, mode_selector):
    if mode_selector == "파일 업로드":
        ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt, flag = real_time_check(image_uploader)
        if flag == 1:
            return gr.update(visible=False), gr.update(visible=False), gr.Text(ko_deplot_generated_txt, visible=True), gr.Text(gpt_generated_txt, visible=True)
        else:
            return ko_deplot_generated_df, gpt_generated_df, gr.update(visible=False), gr.update(visible=False)
    else:
        ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt, flag = real_time_check(image_files[current_image_index])
        if flag == 1:
            return gr.update(visible=False), gr.update(visible=False), gr.Text(ko_deplot_generated_txt, visible=True), gr.Text(gpt_generated_txt, visible=True)
        else:
            return ko_deplot_generated_df, gpt_generated_df, gr.update(visible=False), gr.update(visible=False)
def toggle_model(selected_models, flag):
    # Visibility list for [ko_deplot_generated_df, gpt_generated_df,
    # ko_deplot_generated_txt, gpt_generated_txt, md1, md2], initialized to False
    visibility = [False] * 6
    # Update visibility based on the selected models; when flag is set,
    # show the raw-text component instead of the DataFrame
    if "VAIV_DePlot" in selected_models:
        visibility[4] = True
        if flag:
            visibility[2] = True
        else:
            visibility[0] = True
    if "gpt-4o-mini" in selected_models:
        visibility[5] = True
        if flag:
            visibility[3] = True
        else:
            visibility[1] = True
    if "all" in selected_models:
        visibility[4] = True
        visibility[5] = True
        if flag:
            visibility[2] = True
            visibility[3] = True
        else:
            visibility[0] = True
            visibility[1] = True
    # Return gr.update for each component with the corresponding visibility status
    return tuple(gr.update(visible=v) for v in visibility)
def toggle_mode(mode):
    if mode == "파일 업로드":
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
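# toggle_mode switches between the single-file uploader and the folder uploader,
# returning visibility updates for [image_uploader, folder_uploader, prev_button,
# next_button] (see the event wiring below).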
def display_image(image_file):
    image = Image.open(image_file)
    return image, os.path.basename(image_file)
# State for displaying the images in the uploaded folder sequentially
image_files = []
current_image_index = 0
image_files_cnt = 0

def display_folder_images(image_file_path_list):
    global image_files, current_image_index, image_files_cnt
    image_files = image_file_path_list
    image_files_cnt = len(image_files)
    current_image_index = 0
    if image_files:
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=False), gr.update(interactive=True)
    # Return four values to match the outputs wired below
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)
def next_image():
    global current_image_index
    if image_files:
        current_image_index = current_image_index + 1
        prev_disabled = current_image_index == 0
        next_disabled = current_image_index == (len(image_files) - 1)
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive=not next_disabled)
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)

def prev_image():
    global current_image_index
    if image_files:
        current_image_index = current_image_index - 1
        prev_disabled = current_image_index == 0
        next_disabled = current_image_index == (len(image_files) - 1)
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive=not next_disabled)
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)
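# The prev/next buttons are enabled or disabled from the current index, so the
# user cannot step past either end of the folder and out-of-range indices are
# never reached.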
| css = """ | |
| .dataframe-class { | |
| overflow-y: auto !important; /* μ€ν¬λ‘€μ κ°λ₯νκ² */ | |
| height: 250px | |
| } | |
| """ | |
with gr.Blocks(css=css) as iface:
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center;'>SKKU-VAIV Automatic chart understanding evaluation tool</h1>")
    gr.Markdown("<hr style='border: 1px solid #ddd;' />")
    with gr.Row():
        with gr.Column():
            mode_selector = gr.Radio(["파일 업로드", "폴더 업로드"], label="Upload Mode", value="파일 업로드")
            image_uploader = gr.File(file_count="single", file_types=["image"], visible=True)
            folder_uploader = gr.File(file_count="directory", file_types=["image"], visible=False, height=50)
            model_type = gr.Dropdown(["VAIV_DePlot", "gpt-4o-mini", "all"], value="VAIV_DePlot", label="model", multiselect=True)
            image_displayer = gr.Image(visible=True)
            image_name = gr.Text("", visible=True)
            with gr.Row():
                prev_button = gr.Button("이전", visible=False, interactive=False)
                next_button = gr.Button("다음", visible=False, interactive=False)
            inference_button = gr.Button("추론")
        with gr.Column():
            md1 = gr.Markdown("# VAIV_DePlot Inference Result")
            ko_deplot_generated_df = gr.DataFrame(visible=True, elem_classes="dataframe-class")
            ko_deplot_generated_txt = gr.Text(visible=False)
        with gr.Column():
            md2 = gr.Markdown("# gpt-4o-mini Inference Result", visible=False)
            gpt_generated_df = gr.DataFrame(visible=False, elem_classes="dataframe-class")
            gpt_generated_txt = gr.Text(visible=False)
            # label_df = gr.DataFrame(visible=False, label="Ground Truth Table", elem_classes="dataframe-class", scale=1)
    model_type.change(
        toggle_model,
        inputs=[model_type, gr.State(flag)],
        outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt, md1, md2]
    )
    mode_selector.change(
        toggle_mode,
        inputs=[mode_selector],
        outputs=[image_uploader, folder_uploader, prev_button, next_button]
    )
    image_uploader.upload(display_image, inputs=[image_uploader], outputs=[image_displayer, image_name])
    folder_uploader.upload(display_folder_images, inputs=[folder_uploader], outputs=[image_displayer, image_name, prev_button, next_button])
    prev_button.click(prev_image, outputs=[image_displayer, image_name, prev_button, next_button])
    next_button.click(next_image, outputs=[image_displayer, image_name, prev_button, next_button])
    inference_button.click(inference, inputs=[image_uploader, mode_selector], outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt])
if __name__ == "__main__":
    iface.launch(share=True)