| from typing import Dict, List, Any |
| from PIL import Image |
| from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration |
| import torch |
| import base64 |
| import io |
|
|
|
|
class EndpointHandler:
    """Custom Inference Endpoints handler for a Pix2Struct (DePlot-style) model.

    Loads the processor/model once at startup and, per request, turns a
    base64-encoded chart image plus a text header into generated table text.
    """

    def __init__(self, path: str = ""):
        """Called when the endpoint starts. Load model and processor.

        Args:
            path: Local directory or Hub id holding the model weights.
        """
        self.processor = Pix2StructProcessor.from_pretrained(path)
        self.model = Pix2StructForConditionalGeneration.from_pretrained(path)

        # Prefer GPU when available; keep the model in eval mode for inference.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

        # Default DePlot prompt: extract the underlying data table of a chart.
        self.default_header = "Generate underlying data table of the figure below:"

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every request.

        Args:
            data: Dictionary containing:
                - inputs: base64 encoded image string (a ``data:...;base64,``
                  URI prefix, if present, is stripped automatically)
                - parameters (optional): dict with:
                    - header: text prompt for the model (default: DePlot prompt)
                    - max_new_tokens: max generation length (default: 512)

        Returns:
            List containing the generated table text.

        Raises:
            ValueError: if ``inputs`` is not a string or cannot be decoded
                into an image.
        """
        inputs = data.get("inputs")
        # `or {}` (not a .get default) so an explicit JSON `"parameters": null`
        # doesn't leave us calling .get() on None.
        parameters = data.get("parameters") or {}

        # Accept the prompt under several common key names, in either the
        # parameters dict or the top-level payload; fall back to the default.
        header_text = (
            parameters.get("header") or
            parameters.get("text") or
            parameters.get("prompt") or
            data.get("header") or
            data.get("text") or
            data.get("prompt") or
            self.default_header
        )

        if not isinstance(inputs, str):
            raise ValueError("Expected base64 encoded image string in 'inputs'")

        # Tolerate data-URI payloads ("data:image/png;base64,....") by keeping
        # only the part after the first comma.
        if inputs.startswith("data:"):
            _, _, inputs = inputs.partition(",")

        try:
            image_bytes = base64.b64decode(inputs)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            # Chain the original decode/parse error for easier debugging.
            raise ValueError(f"Failed to decode base64 image: {e}") from e

        model_inputs = self.processor(
            images=image,
            text=header_text,
            return_tensors="pt"
        ).to(self.device)

        # int(...) so JSON clients that send the value as a string still work.
        max_new_tokens = int(parameters.get("max_new_tokens", 512))

        with torch.no_grad():
            predictions = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens
            )

        output_text = self.processor.decode(
            predictions[0],
            skip_special_tokens=True
        )

        return [{"generated_text": output_text}]