| import gradio | |
| import subprocess | |
| from PIL import Image | |
| import torch, torch.backends.cudnn, torch.backends.cuda | |
| from min_dalle import MinDalle | |
| from emoji import demojize | |
| import string | |
def filename_from_text(text: str) -> str:
    """Derive a short file stem from a free-form prompt.

    The prompt is passed through ``demojize`` with empty delimiters,
    lowercased, and reduced to ASCII lowercase letters and spaces. The
    result is cut to 64 characters and its words are joined with hyphens.

    Returns:
        The hyphen-joined stem, or ``'blank'`` when no letters survive.
    """
    keep = set(string.ascii_lowercase + ' ')
    described = demojize(text, delimiters=['', ''])
    # Encoding with errors='ignore' silently drops every non-ASCII character.
    ascii_only = described.lower().encode('ascii', errors='ignore').decode()
    letters = ''.join(filter(keep.__contains__, ascii_only))
    # Truncate before splitting so the 64-character cap applies to the raw text.
    words = letters[:64].strip().split()
    return '-'.join(words) if words else 'blank'
def log_gpu_memory():
    """Print the output of ``nvidia-smi`` as a GPU memory diagnostic.

    Logging is best-effort: if ``nvidia-smi`` is not installed or exits with
    a non-zero status, a short notice is printed instead of raising, so a
    diagnostics failure can never abort model loading or discard an image
    that has already been generated.
    """
    try:
        # Pass an argv list so no shell is ever involved.
        report = subprocess.check_output(['nvidia-smi'])
    except (OSError, subprocess.CalledProcessError) as err:
        print('nvidia-smi unavailable:', err)
        return
    # errors='replace' keeps odd bytes in the tool output from raising here.
    print(report.decode('utf-8', errors='replace'))
# Log GPU usage before and after constructing the model so the difference
# between the two reports shows up in the startup log.
log_gpu_memory()

# The model is built once at import time and shared by every request.
model = MinDalle(
    dtype=torch.float32,
    device='cuda',
    is_reusable=True,
    is_mega=True,
)

log_gpu_memory()
def _parse_setting(raw, convert, is_valid, message):
    """Convert one raw UI value and range-check it, raising ValueError(message) on failure."""
    try:
        parsed = convert(raw)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(message) from err
    if not is_valid(parsed):
        raise ValueError(message)
    return parsed


def run_model(
    text: str,
    grid_size: int,
    is_seamless: bool,
    save_as_png: bool,
    temperature: float,
    supercondition: str,
    top_k: str
) -> str:
    """Generate an image grid for ``text`` and return the saved file's path.

    Args:
        text: Prompt passed to the model; also used to derive the file name.
        grid_size: Grid side length, must convert to an int in [1, 5].
        is_seamless: Forwarded to ``model.generate_image`` as a bool.
        save_as_png: Save as ``.png`` when truthy, otherwise as ``.jpg``.
        temperature: Sampling temperature, must convert to a float > 1e-6.
        supercondition: Super-condition factor, must convert to a float.
        top_k: Top-k cutoff, must convert to an int in [1, 16384].

    Returns:
        Path of the image file written to the current working directory.

    Raises:
        ValueError: If any numeric setting cannot be parsed or is out of
            range. Explicit checks are used instead of ``assert`` so the
            validation still runs under ``python -O``.
    """
    print('text:', text)
    print('grid_size:', grid_size)
    print('is_seamless:', is_seamless)
    print('temperature:', temperature)
    print('supercondition:', supercondition)
    print('top_k:', top_k)

    # Validate everything before touching global torch state or the GPU.
    # Note: NaN fails the `> 1e-6` comparison and is therefore rejected.
    temperature = _parse_setting(
        temperature, float, lambda value: value > 1e-6,
        'Temperature must be a positive nonzero number'
    )
    grid_size = _parse_setting(
        grid_size, int, lambda value: 1 <= value <= 5,
        'Grid size must be between 1 and 5'
    )
    top_k = _parse_setting(
        top_k, int, lambda value: 1 <= value <= 16384,
        'Top k must be between 1 and 16384'
    )
    supercondition_factor = _parse_setting(
        supercondition, float, lambda value: True,
        'Super condition must be a number'
    )

    # Inference-only configuration: no autograd, and cuDNN/cuBLAS settings
    # that allow autotuning and reduced-precision math.
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

    with torch.no_grad():
        image = model.generate_image(
            text=text,
            seed=-1,
            grid_size=grid_size,
            is_seamless=bool(is_seamless),
            temperature=temperature,
            supercondition_factor=supercondition_factor,
            top_k=top_k,
            is_verbose=True
        )

    log_gpu_memory()

    ext = 'png' if bool(save_as_png) else 'jpg'
    filename = filename_from_text(text)
    image_path = f'{filename}.{ext}'
    image.save(image_path)
    return image_path
# Build the Gradio interface, wire the button to run_model, and start serving.
# NOTE(review): the nesting below was reconstructed from a flattened source in
# which indentation had been lost; confirm the layout levels (in particular
# where gradio.Examples sits) against the running app.
demo = gradio.Blocks(analytics_enabled=True)
with demo:
    with gradio.Row():
        # First column: prompt box, generate button and the output image.
        with gradio.Column():
            input_text = gradio.Textbox(
                label='Input Text',
                value='Moai statue giving a TED Talk',
                lines=3
            )
            run_button = gradio.Button(value='Generate Image').style(full_width=True)
            # Starts out showing a bundled example image; not user-editable.
            output_image = gradio.Image(
                value='examples/moai-statue.jpg',
                label='Output Image',
                type='file',
                interactive=False
            )
        # Second column: generation settings and their explanation.
        with gradio.Column():
            gradio.Markdown('## Settings')
            with gradio.Row():
                # Same 1..5 range that run_model enforces for grid_size.
                grid_size = gradio.Slider(
                    label='Grid Size',
                    value=5,
                    minimum=1,
                    maximum=5,
                    step=1
                )
                save_as_png = gradio.Checkbox(
                    label='Output PNG',
                    value=False
                )
                is_seamless = gradio.Checkbox(
                    label='Seamless',
                    value=False
                )
            gradio.Markdown('#### Advanced')
            with gradio.Row():
                temperature = gradio.Number(
                    label='Temperature',
                    value=1
                )
                # Choices are the powers of two from 1 to 16384, as strings;
                # run_model converts the selected value with int().
                top_k = gradio.Dropdown(
                    label='Top-k',
                    choices=[str(2 ** i) for i in range(15)],
                    value='128'
                )
                # Choices are the powers of two from 4 to 64, as strings;
                # run_model converts the selected value with float().
                supercondition = gradio.Dropdown(
                    label='Super Condition',
                    choices=[str(2 ** i) for i in range(2, 7)],
                    value='16'
                )
            gradio.Markdown(
                """
                ####
                - **Input Text**: For long prompts, only the first 64 text tokens will be used to generate the image.
                - **Grid Size**: Size of the image grid. 3x3 takes about 15 seconds.
                - **Seamless**: Tile images in image token space instead of pixel space.
                - **Temperature**: High temperature increases the probability of sampling low scoring image tokens.
                - **Top-k**: Each image token is sampled from the top-k scoring tokens.
                - **Super Condition**: Higher values can result in better agreement with the text.
                """
            )
    # Each example row supplies, in order, values for the components listed
    # in `inputs`: the prompt text, the grid size, and an image path for the
    # output image component.
    gradio.Examples(
        examples=[
            ['Rusty Iron Man suit found abandoned in the woods being reclaimed by nature', 3, 'examples/rusty-iron-man.jpg'],
            ['Moai statue giving a TED Talk', 5, 'examples/moai-statue.jpg'],
            ['Court sketch of Godzilla on trial', 5, 'examples/godzilla-trial.jpg'],
            ['lofi nuclear war to relax and study to', 5, 'examples/lofi-nuclear-war.jpg'],
            ['Karl Marx slimed at Kids Choice Awards', 4, 'examples/marx-slimed.jpg'],
            ['Scientists trying to rhyme orange with banana', 4, 'examples/scientists-rhyme.jpg'],
            ['Jesus turning water into wine on Americas Got Talent', 5, 'examples/jesus-talent.jpg'],
            ['Elmo in a street riot throwing a Molotov cocktail, hyperrealistic', 5, 'examples/elmo-riot.jpg'],
            ['Trail cam footage of gollum eating watermelon', 4, 'examples/gollum.jpg'],
            ['Funeral at Whole Foods', 4, 'examples/funeral-whole-foods.jpg'],
            ['Singularity, hyperrealism', 5, 'examples/singularity.jpg'],
            ['Astronaut riding a horse hyperrealistic', 5, 'examples/astronaut-horse.jpg'],
            ['An astronaut walking on Mars next to a Starship rocket, realistic', 5, 'examples/astronaut-mars.jpg'],
            ['Nuclear explosion broccoli', 4, 'examples/nuclear-broccoli.jpg'],
            ['Dali painting of WALL·E', 5, 'examples/dali-walle.jpg'],
            ['Cleopatra checking her iPhone', 4, 'examples/cleopatra-iphone.jpg'],
        ],
        inputs=[
            input_text,
            grid_size,
            output_image
        ],
        examples_per_page=20
    )
    # The order of `inputs` must match run_model's positional parameters:
    # text, grid_size, is_seamless, save_as_png, temperature, supercondition,
    # top_k. The returned image path is shown in output_image.
    run_button.click(
        fn=run_model,
        inputs=[
            input_text,
            grid_size,
            is_seamless,
            save_as_png,
            temperature,
            supercondition,
            top_k
        ],
        outputs=[
            output_image
        ]
    )
# Runs at import time: the app starts serving as soon as this module executes.
demo.launch()