import spaces
import gradio as gr
import torch
import gc, os, uuid, json
from PIL import PngImagePlugin
from diffusers import DiffusionPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
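# On Hugging Face ZeroGPU Spaces, prefer deterministic cuDNN kernels and allow
# TF32 matmuls for faster float32 GEMMs (see the linked PyTorch post below).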
if os.getenv("SPACES_ZERO_GPU", None):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high") # https://pytorch.org/blog/accelerating-generative-ai-3/


def load_pipeline():
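    """Load the SDXL checkpoint with an Euler Ancestral scheduler and keep it on CPU until a request arrives."""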
    #vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
    pipe = DiffusionPipeline.from_pretrained(
        #"John6666/rae-diffusion-xl-v2-sdxl-spo-pcm",
        "Raelina/Raehoshi-illust-XL-8",
        #custom_pipeline="lpw_stable_diffusion_xl",
        #custom_pipeline="nyanko7/sdxl_smoothed_energy_guidance",
        torch_dtype=dtype,
        #vae=vae,
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to("cpu")
    return pipe


def token_auto_concat_embeds(pipe, positive, negative):
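    """Work around the tokenizer's length limit (pipe.tokenizer.model_max_length, 77 for SDXL's CLIP)
    by encoding long prompts in chunks and concatenating the resulting embeddings.
    Currently unused; see the commented-out call in generate_image.
    """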
    max_length = pipe.tokenizer.model_max_length
    positive_length = pipe.tokenizer(positive, return_tensors="pt").input_ids.shape[-1]
    negative_length = pipe.tokenizer(negative, return_tensors="pt").input_ids.shape[-1]
    
    print(f'Model max token length: {max_length}, positive prompt: {positive_length} tokens, negative prompt: {negative_length} tokens.')
    if max_length < positive_length or max_length < negative_length:
        print('Prompt exceeds the token limit; concatenating chunked embeddings.')
        # Pad the shorter prompt so both token sequences share the same length.
        if positive_length > negative_length:
            positive_ids = pipe.tokenizer(positive, return_tensors="pt").input_ids.to(device)
            negative_ids = pipe.tokenizer(negative, truncation=False, padding="max_length", max_length=positive_ids.shape[-1], return_tensors="pt").input_ids.to(device)
        else:
            negative_ids = pipe.tokenizer(negative, return_tensors="pt").input_ids.to(device)
            positive_ids = pipe.tokenizer(positive, truncation=False, padding="max_length", max_length=negative_ids.shape[-1], return_tensors="pt").input_ids.to(device)
    else:
        positive_ids = pipe.tokenizer(positive, truncation=False, padding="max_length", max_length=max_length, return_tensors="pt").input_ids.to(device)
        negative_ids = pipe.tokenizer(negative, truncation=False, padding="max_length", max_length=max_length, return_tensors="pt").input_ids.to(device)
    
    positive_concat_embeds = []
    negative_concat_embeds = []
    for i in range(0, positive_ids.shape[-1], max_length):
        positive_concat_embeds.append(pipe.text_encoder(positive_ids[:, i: i + max_length])[0])
        negative_concat_embeds.append(pipe.text_encoder(negative_ids[:, i: i + max_length])[0])
    
    positive_prompt_embeds = torch.cat(positive_concat_embeds, dim=1)
    negative_prompt_embeds = torch.cat(negative_concat_embeds, dim=1)
    return positive_prompt_embeds, negative_prompt_embeds


def save_image(image, metadata, output_dir):
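    """Write the image as a PNG with the generation metadata embedded in a tEXt chunk and return its path."""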
    filename = str(uuid.uuid4()) + ".png"
    os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, filename)
    metadata_str = json.dumps(metadata)
    info = PngImagePlugin.PngInfo()
    info.add_text("metadata", metadata_str)
    image.save(filepath, "PNG", pnginfo=info)
    return filepath


pipe = load_pipeline()


@torch.inference_mode()
@spaces.GPU(duration=15)
def generate_image(prompt, neg_prompt, progress=gr.Progress(track_tqdm=True)):
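    """Gradio handler: run SDXL at 1024x1024 (Euler a, 28 steps, CFG 6.0) on the
    ZeroGPU device, save the results under ./outputs, and return the file paths."""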
    pipe.to(device)
    #prompt += ", masterpiece, best quality, very aesthetic, absurdres"
    #neg_prompt += "bad hands, bad feet, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], photo, deformed, disfigured, low contrast, photo, deformed, disfigured, low contrast"
    # Append the default negative tags, keeping a separator when the user supplied a negative prompt.
    default_neg = "bad quality, worst quality, poorly drawn, sketch, multiple views, bad anatomy, bad hands, missing fingers, extra fingers, extra digits, fewer digits, signature, watermark, username"
    neg_prompt = f"{neg_prompt}, {default_neg}" if neg_prompt else default_neg
    width = 1024
    height = 1024
    cfg = 6.0
    steps = 28
    metadata = {
        "prompt": prompt,
        "negative_prompt": neg_prompt,
        "resolution": f"{width} x {height}",
        "guidance_scale": cfg,
        "num_inference_steps": steps,
        "sampler": "Euler a",
    }
    try: 
        #positive_embeds, negative_embeds = token_auto_concat_embeds(pipe, prompt, neg_prompt)
        images = pipe(
            prompt=prompt,
            negative_prompt=neg_prompt,
            width=width,
            height=height,
            guidance_scale=cfg,  # seg_scale=3.0, seg_applied_layers=["mid"],
            num_inference_steps=steps,
            output_type="pil",
            #clip_skip=1,
        ).images
        # pipe(...) always returns a list; an empty one simply yields no saved paths.
        image_paths = [
            save_image(image, metadata, "./outputs")
            for image in images
        ]
        return image_paths
    except Exception as e:
        print(e)
        return []
    finally:
        pipe.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()
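

# --- UI wiring (assumption) ---
# The code above only defines the generation backend; the Space presumably builds
# its interface elsewhere. A minimal sketch of how generate_image could be exposed
# through Gradio, assuming a simple two-textbox layout (names below are hypothetical):
if __name__ == "__main__":
    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        neg_prompt = gr.Textbox(label="Negative prompt")
        gallery = gr.Gallery(label="Output")
        run = gr.Button("Generate")
        # generate_image returns a list of saved PNG paths, which Gallery accepts directly.
        run.click(generate_image, inputs=[prompt, neg_prompt], outputs=gallery)
    demo.queue().launch()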