Darius Morawiec committed
Commit 9afc0f5 · Parent: 028f4ca

Refactor model loading and processing

Files changed (1): app.py (+20 / -14)
app.py CHANGED
@@ -145,17 +145,8 @@ with gr.Blocks() as demo:
     current_processor = None
     current_model_id = None
 
-    @spaces.GPU(duration=300)
-    def run(
-        image,
-        model_id: str,
-        system_prompt: str,
-        user_prompt: str,
-        max_new_tokens: int = 1024,
-        image_target_size: int | None = None,
-    ):
+    def load_model(model_id: str):
         global current_model, current_processor, current_model_id
-        scale = False if model_id.startswith("Qwen/Qwen2.5-VL") else True
 
         # Only load model if it's different from the currently loaded one
         if current_model_id != model_id or current_model is None:
@@ -176,12 +167,14 @@ with gr.Blocks() as demo:
                 torch.cuda.synchronize()
 
             # Load new model
+            model_loader = None
             if model_id.startswith("Qwen/Qwen2-VL"):
                 model_loader = Qwen2VLForConditionalGeneration
             elif model_id.startswith("Qwen/Qwen2.5-VL"):
                 model_loader = Qwen2_5_VLForConditionalGeneration
             elif model_id.startswith("Qwen/Qwen3-VL"):
                 model_loader = Qwen3VLForConditionalGeneration
+            assert model_loader is not None, f"Unsupported model ID: {model_id}"
             current_model = model_loader.from_pretrained(
                 model_id,
                 torch_dtype="auto",
@@ -190,13 +183,21 @@ with gr.Blocks() as demo:
             current_processor = AutoProcessor.from_pretrained(model_id)
             current_model_id = model_id
 
-        model = current_model
-        processor = current_processor
+        return current_model, current_processor
+
+    def run(
+        image,
+        model_id: str,
+        system_prompt: str,
+        user_prompt: str,
+        max_new_tokens: int = 1024,
+        image_target_size: int | None = None,
+    ):
+        model, processor = load_model(model_id)
 
         base64_image = image_to_base64(
             scale_image(image, image_target_size) if image_target_size else image
         )
-
         messages = [
             {
                 "role": "user",
@@ -226,7 +227,11 @@ with gr.Blocks() as demo:
         )
         inputs = inputs.to(DEVICE)
 
-        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
+        @spaces.GPU(duration=300)
+        def _generate(**kwargs):
+            return model.generate(**kwargs)
+
+        generated_ids = _generate(**inputs, max_new_tokens=max_new_tokens)
         generated_ids_trimmed = [
             out_ids[len(in_ids) :]
             for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
@@ -241,6 +246,7 @@ with gr.Blocks() as demo:
         output_text = repair_json(output_text)
         output_json = json.loads(output_text)
 
+        scale = False if model_id.startswith("Qwen/Qwen2.5-VL") else True
         x_scale = float(image.width / 1000) if scale else 1.0
         y_scale = float(image.height / 1000) if scale else 1.0
         bboxes = []
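
For context, a minimal standalone sketch of the lazy, cached loading that this commit factors out into load_model(). It is a simplification rather than the full app.py: the cleanup branch and the class-selection chain mirror the diff above, and it assumes a transformers release that ships all three Qwen VL classes.

import torch
from transformers import (
    AutoProcessor,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    Qwen3VLForConditionalGeneration,
)

current_model = None
current_processor = None
current_model_id = None

def load_model(model_id: str):
    global current_model, current_processor, current_model_id

    # Reuse the cached pair when the same checkpoint is requested again.
    if current_model_id == model_id and current_model is not None:
        return current_model, current_processor

    # Drop the old weights and reclaim GPU memory before loading new ones.
    if current_model is not None:
        current_model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    # Pick the model class by checkpoint prefix; fail loudly on unknown IDs.
    model_loader = None
    if model_id.startswith("Qwen/Qwen2-VL"):
        model_loader = Qwen2VLForConditionalGeneration
    elif model_id.startswith("Qwen/Qwen2.5-VL"):
        model_loader = Qwen2_5_VLForConditionalGeneration
    elif model_id.startswith("Qwen/Qwen3-VL"):
        model_loader = Qwen3VLForConditionalGeneration
    assert model_loader is not None, f"Unsupported model ID: {model_id}"

    current_model = model_loader.from_pretrained(model_id, torch_dtype="auto")
    current_processor = AutoProcessor.from_pretrained(model_id)
    current_model_id = model_id
    return current_model, current_processor

Calling load_model() twice with the same model_id downloads and instantiates the weights only once; switching checkpoints frees the old model first.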
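
The second half of the refactor narrows the GPU reservation: @spaces.GPU(duration=300) moves from the whole run() handler onto an inner closure around model.generate(), so a ZeroGPU slot is held only while tokens are generated, not while weights load or inputs are prepared. A sketch of that pattern, assuming the spaces package available on Hugging Face Spaces (generate_text and its parameters are illustrative names, not app.py's API):

import spaces

def generate_text(model, inputs, max_new_tokens: int = 1024):
    # Reserve a GPU only for the duration of this inner call.
    @spaces.GPU(duration=300)
    def _generate(**kwargs):
        return model.generate(**kwargs)

    return _generate(**inputs, max_new_tokens=max_new_tokens)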