import tempfile
import unittest

import numpy as np
import torch

from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor


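# `Attention` configured with `added_kv_proj_dim` routes `encoder_hidden_states`
# through separate added K/V projection layers, which `AttnAddedKVProcessor`
# consumes. A minimal usage sketch (illustrative only, not executed by the tests):
#
#     attn = Attention(query_dim=10, heads=2, dim_head=4, added_kv_proj_dim=6,
#                      norm_num_groups=1, processor=AttnAddedKVProcessor())
#     out = attn(torch.rand(2, 10, 3, 2), encoder_hidden_states=torch.rand(2, 4, 6))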
class AttnAddedKVProcessorTests(unittest.TestCase):
    def get_constructor_arguments(self, only_cross_attention: bool = False):
        query_dim = 10

        if only_cross_attention:
            cross_attention_dim = 12
        else:
            # when only_cross_attention is not set, the cross attention dim must match the query dim
            cross_attention_dim = query_dim

        return {
            "query_dim": query_dim,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "added_kv_proj_dim": 6,
            "norm_num_groups": 1,
            "only_cross_attention": only_cross_attention,
            "processor": AttnAddedKVProcessor(),
        }
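
    # Dimension note (an assumption about diffusers' `Attention` internals): the
    # attention width is heads * dim_head = 8, and the added K/V projections consume
    # the 6-dim `encoder_hidden_states` implied by `added_kv_proj_dim=6`.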

    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
        batch_size = 2

        # AttnAddedKVProcessor consumes a spatial feature map: (batch, channels, height, width)
        hidden_states = torch.rand(batch_size, query_dim, 3, 2)
        # the added conditioning is a sequence: (batch, sequence_length, added_kv_proj_dim)
        encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
        attention_mask = None

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
        }
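
    # attn(**forward_args) should return a tensor with the same (batch, channels,
    # height, width) shape as hidden_states: the processor flattens the spatial dims
    # for attention, reshapes back, and adds the residual (assuming the current
    # AttnAddedKVProcessor implementation).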

    def test_only_cross_attention(self):
        # self attention + added-KV cross attention
        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=False)
        attn = Attention(**constructor_args)

        # the self-attention K/V projections exist in this configuration
        self.assertTrue(attn.to_k is not None)
        self.assertTrue(attn.to_v is not None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        self_and_cross_attn_out = attn(**forward_args)

        # only cross attention
        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=True)
        attn = Attention(**constructor_args)

        # with only_cross_attention=True, to_k/to_v are never created
        self.assertTrue(attn.to_k is None)
        self.assertTrue(attn.to_v is None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        only_cross_attn_out = attn(**forward_args)

        # identical seeds and inputs, so the two configurations must disagree elementwise
        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())


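# The test below covers the on-load conversion of deprecated attention weights that
# is triggered when a pipeline is loaded with a `device_map` (inferred from the class
# and test names): outputs must match before conversion, after conversion, and after
# a save/reload round trip.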
class DeprecatedAttentionBlockTests(unittest.TestCase):
    def test_conversion_when_using_device_map(self):
        pipe = DiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
        )

        # baseline images from the pipeline loaded without a device map
        pre_conversion = pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

        # loading with a device map exercises the conversion path
        pipe = DiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe", device_map="sequential", safety_checker=None
        )

        conversion = pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

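        # Round trip: saving the converted pipeline and reloading it (again with a
        # device map) should reproduce the exact same images, which is asserted below.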
        with tempfile.TemporaryDirectory() as tmpdir:
            # save the converted weights
            pipe.save_pretrained(tmpdir)

            # ...and load them back, still using a device map
            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="sequential", safety_checker=None)

            after_conversion = pipe(
                "foo",
                num_inference_steps=2,
                generator=torch.Generator("cpu").manual_seed(0),
                output_type="np",
            ).images

        self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-5))
        self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-5))