Instructions to use BiliSakura/JiT-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/JiT-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| from typing import Callable | |
| import torch | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput | |
| class JiTScheduler(SchedulerMixin, ConfigMixin): | |
| order = 1 | |
| def __init__( | |
| self, | |
| solver: str = "heun", | |
| timestep_start: float = 0.0, | |
| timestep_end: float = 1.0, | |
| ): | |
| if solver not in {"heun", "euler"}: | |
| raise ValueError("solver must be one of: 'heun', 'euler'.") | |
| if timestep_end <= timestep_start: | |
| raise ValueError("timestep_end must be greater than timestep_start.") | |
| self.timesteps = torch.tensor([]) | |
| def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None): | |
| if num_inference_steps < 2: | |
| raise ValueError("num_inference_steps must be >= 2.") | |
| self.timesteps = torch.linspace( | |
| self.config.timestep_start, | |
| self.config.timestep_end, | |
| num_inference_steps + 1, | |
| device=device, | |
| dtype=torch.float32, | |
| ) | |
| def euler_step( | |
| self, | |
| model_output: torch.Tensor, | |
| timestep: torch.Tensor, | |
| next_timestep: torch.Tensor, | |
| sample: torch.Tensor, | |
| return_dict: bool = True, | |
| ) -> SchedulerOutput | tuple[torch.Tensor]: | |
| prev_sample = sample + (next_timestep - timestep) * model_output | |
| if not return_dict: | |
| return (prev_sample,) | |
| return SchedulerOutput(prev_sample=prev_sample) | |
| def step( | |
| self, | |
| model_output: torch.Tensor, | |
| timestep: torch.Tensor, | |
| next_timestep: torch.Tensor, | |
| sample: torch.Tensor, | |
| model_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, | |
| return_dict: bool = True, | |
| ) -> SchedulerOutput | tuple[torch.Tensor]: | |
| if self.config.solver == "euler": | |
| return self.euler_step(model_output, timestep, next_timestep, sample, return_dict=return_dict) | |
| if model_fn is None: | |
| raise ValueError("model_fn is required when solver='heun'.") | |
| sample_euler = sample + (next_timestep - timestep) * model_output | |
| model_output_next = model_fn(sample_euler, next_timestep) | |
| prev_sample = sample + (next_timestep - timestep) * 0.5 * (model_output + model_output_next) | |
| if not return_dict: | |
| return (prev_sample,) | |
| return SchedulerOutput(prev_sample=prev_sample) | |