Update app.py
Browse files
app.py
CHANGED
|
@@ -129,15 +129,7 @@ if 'generate' not in st.session_state:
|
|
| 129 |
|
| 130 |
# Inizializza inference_tester solo una volta
|
| 131 |
if 'inference_tester' not in st.session_state:
|
| 132 |
-
|
| 133 |
-
st.session_state['inference_tester'] = dani_model(model='thesis_model',
|
| 134 |
-
data_dir='/mimer/NOBACKUP/groups/snic2022-5-277/dmolino/checkpoints/',
|
| 135 |
-
pth=model_load_paths, load_weights=False)
|
| 136 |
-
inference_tester = st.session_state['inference_tester']
|
| 137 |
-
|
| 138 |
-
# Caricamento dei pesi Clip, Optimus, Frontal, Lateral e Text una sola volta
|
| 139 |
-
if 'weights_loaded' not in st.session_state:
|
| 140 |
-
st.session_state['weights_loaded'] = True # Indica che i pesi sono stati caricati
|
| 141 |
|
| 142 |
# Usa inference_tester dalla sessione
|
| 143 |
inference_tester = st.session_state['inference_tester']
|
|
@@ -209,18 +201,12 @@ if st.session_state['step'] == 2:
|
|
| 209 |
|
| 210 |
# Pulsante per provare un esempio
|
| 211 |
with col1:
|
| 212 |
-
if st.button("Inference"):
|
| 213 |
-
st.session_state['step'] = 3 # Passa al passo 3
|
| 214 |
-
st.rerun()
|
| 215 |
-
|
| 216 |
-
# Pulsante per provare un esempio
|
| 217 |
-
with col2:
|
| 218 |
if st.button("Try an example"):
|
| 219 |
st.session_state['step'] = 5 # Passa al passo 5
|
| 220 |
st.rerun()
|
| 221 |
|
| 222 |
# Pulsante per tornare all'inizio
|
| 223 |
-
with
|
| 224 |
if st.button("Return to the beginning"):
|
| 225 |
# Ripristina lo stato della sessione
|
| 226 |
st.session_state['step'] = 1
|
|
@@ -378,79 +364,8 @@ if st.session_state['step'] == 3:
|
|
| 378 |
st.rerun()
|
| 379 |
|
| 380 |
if st.session_state['step'] == 4:
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
conditioning = []
|
| 384 |
-
for inp in st.session_state['inputs']:
|
| 385 |
-
if inp == 'frontal':
|
| 386 |
-
cim = inference_tester.net.clip_encode_vision(st.session_state['frontal'], encode_type='encode_vision').to(device)
|
| 387 |
-
uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['frontal']).to(device),
|
| 388 |
-
encode_type='encode_vision').to(device)
|
| 389 |
-
conditioning.append(torch.cat([uim, cim]))
|
| 390 |
-
elif inp == 'lateral':
|
| 391 |
-
cim = inference_tester.net.clip_encode_vision(st.session_state['lateral'], encode_type='encode_vision').to(device)
|
| 392 |
-
uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['lateral']).to(device),
|
| 393 |
-
encode_type='encode_vision').to(device)
|
| 394 |
-
conditioning.append(torch.cat([uim, cim]))
|
| 395 |
-
elif inp == 'text':
|
| 396 |
-
ctx = inference_tester.net.clip_encode_text(1 * [st.session_state['report']], encode_type='encode_text').to(device)
|
| 397 |
-
utx = inference_tester.net.clip_encode_text(1 * [""], encode_type='encode_text').to(device)
|
| 398 |
-
conditioning.append(torch.cat([utx, ctx]))
|
| 399 |
-
|
| 400 |
-
# Costruzione delle shapes
|
| 401 |
-
shapes = []
|
| 402 |
-
for out in st.session_state['outputs']:
|
| 403 |
-
if out == 'frontal' or out == 'lateral':
|
| 404 |
-
shape = [1, 4, 256 // 8, 256 // 8]
|
| 405 |
-
shapes.append(shape)
|
| 406 |
-
elif out == 'text':
|
| 407 |
-
shape = [1, 768]
|
| 408 |
-
shapes.append(shape)
|
| 409 |
-
|
| 410 |
-
progress_bar = st.progress(0)
|
| 411 |
-
|
| 412 |
-
# Inferenza
|
| 413 |
-
z, _ = inference_tester.sampler.sample(
|
| 414 |
-
steps=50,
|
| 415 |
-
shape=shapes,
|
| 416 |
-
condition=conditioning,
|
| 417 |
-
unconditional_guidance_scale=7.5,
|
| 418 |
-
xtype=st.session_state['outputs'],
|
| 419 |
-
condition_types=st.session_state['inputs'],
|
| 420 |
-
eta=1,
|
| 421 |
-
verbose=False,
|
| 422 |
-
mix_weight={'lateral': 1, 'text': 1, 'frontal': 1},
|
| 423 |
-
progress_bar=progress_bar)
|
| 424 |
-
|
| 425 |
-
# Decoder e visualizzazione dei risultati
|
| 426 |
-
output_cols = st.columns(len(st.session_state['outputs']))
|
| 427 |
-
|
| 428 |
-
# Definire due colonne per le immagini
|
| 429 |
-
col1, col2 = st.columns(2)
|
| 430 |
-
|
| 431 |
-
# Iterare sugli output e assegnare le immagini alle colonne corrispondenti
|
| 432 |
-
for i, out in enumerate(st.session_state['outputs']):
|
| 433 |
-
if out == 'frontal':
|
| 434 |
-
x = inference_tester.net.autokl_decode(z[i])
|
| 435 |
-
x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
|
| 436 |
-
im = x[0].cpu().numpy()
|
| 437 |
-
with col1: # Mostrare la frontal image nella prima colonna
|
| 438 |
-
st.image(im, caption="Generated Frontal Image")
|
| 439 |
-
elif out == 'lateral':
|
| 440 |
-
x = inference_tester.net.autokl_decode(z[i])
|
| 441 |
-
x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
|
| 442 |
-
im = x[0].cpu().numpy()
|
| 443 |
-
with col2: # Mostrare la lateral image nella seconda colonna
|
| 444 |
-
st.image(im, caption="Generated Lateral Image")
|
| 445 |
-
elif out == 'text':
|
| 446 |
-
x = inference_tester.net.optimus_decode(z[i], max_length=100)
|
| 447 |
-
x = [a.tolist() for a in x]
|
| 448 |
-
rec_text = [inference_tester.net.optimus.tokenizer_decoder.decode(a) for a in x]
|
| 449 |
-
rec_text = rec_text[0].replace('<BOS>', '').replace('<EOS>', '')
|
| 450 |
-
st.write(f"Generated Report: {rec_text}")
|
| 451 |
-
|
| 452 |
-
st.write("Generation completed successfully!")
|
| 453 |
-
st.session_state['generate'] = False
|
| 454 |
|
| 455 |
if st.button("Return to the beginning"):
|
| 456 |
# Ripristina lo stato della sessione
|
|
@@ -564,4 +479,4 @@ if st.session_state['step'] == 5:
|
|
| 564 |
st.session_state['frontal_file'] = None
|
| 565 |
st.session_state['lateral_file'] = None
|
| 566 |
st.session_state['report'] = ""
|
| 567 |
-
st.rerun()
|
|
|
|
| 129 |
|
| 130 |
# Inizializza inference_tester solo una volta
|
| 131 |
if 'inference_tester' not in st.session_state:
|
| 132 |
+
st.session_state['inference_tester'] = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
# Usa inference_tester dalla sessione
|
| 135 |
inference_tester = st.session_state['inference_tester']
|
|
|
|
| 201 |
|
| 202 |
# Pulsante per provare un esempio
|
| 203 |
with col1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
if st.button("Try an example"):
|
| 205 |
st.session_state['step'] = 5 # Passa al passo 5
|
| 206 |
st.rerun()
|
| 207 |
|
| 208 |
# Pulsante per tornare all'inizio
|
| 209 |
+
with col2:
|
| 210 |
if st.button("Return to the beginning"):
|
| 211 |
# Ripristina lo stato della sessione
|
| 212 |
st.session_state['step'] = 1
|
|
|
|
| 364 |
st.rerun()
|
| 365 |
|
| 366 |
if st.session_state['step'] == 4:
|
| 367 |
+
st.write("Generation completed successfully!")
|
| 368 |
+
st.session_state['generate'] = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
if st.button("Return to the beginning"):
|
| 371 |
# Ripristina lo stato della sessione
|
|
|
|
| 479 |
st.session_state['frontal_file'] = None
|
| 480 |
st.session_state['lateral_file'] = None
|
| 481 |
st.session_state['report'] = ""
|
| 482 |
+
st.rerun()
|