| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from einops import rearrange |
| | from torch import nn |
| | from transformers.activations import ACT2FN |
| | from transformers.utils import logging |
| |
|
| | from .configuration_navil_vit import NaViLVisionConfig |
| |
|
# Optional dependency: flash_attn provides fused varlen attention kernels and
# a fused rotary-embedding op. When unavailable, the module falls back to the
# naive / SDPA attention paths.
try:
    from flash_attn import flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb
    has_flash_attn = True
except Exception:
    # Best-effort import: any failure (missing package, broken CUDA build)
    # disables flash attention. `except Exception:` instead of a bare
    # `except:` so KeyboardInterrupt/SystemExit still propagate.
    print('FlashAttention is not installed.')
    has_flash_attn = False
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
class InternRMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias).

    Scales the last dimension by the reciprocal of its RMS and applies a
    learned per-channel weight. The reduction runs in float32 for numerical
    stability before casting back to the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalized = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
| |
|
| |
|
# Optionally swap the pure-PyTorch InternRMSNorm for NVIDIA apex's fused
# kernel. A plain ImportError (apex not installed) is ignored silently; any
# other failure (e.g. apex present but built against a mismatched
# CUDA/PyTorch) falls back to InternRMSNorm with a warning. Handler order
# matters: ImportError must be matched before the generic Exception.
try:
    from apex.normalization import FusedRMSNorm

    # Rebind the module-level name so all later users pick up the fused op.
    InternRMSNorm = FusedRMSNorm

    logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
except ImportError:
    # apex not installed; keep the pure-PyTorch implementation.
    pass
except Exception:
    logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
    pass
| |
|
| |
|
# Maps a norm-type string (as carried by the vision config) to the
# corresponding normalization layer class.
NORM2FN = {
    'rms_norm': InternRMSNorm,
    'layer_norm': nn.LayerNorm,
}
| |
|
| |
|
class InternVisionRotaryEmbedding(nn.Module):
    """Precomputes rotary position-embedding frequencies.

    Holds the standard inverse-frequency vector ``1 / theta**(2i / dim)`` as
    a non-persistent buffer; calling the module with a sequence length
    returns the outer product of positions and inverse frequencies, shape
    ``(seqlen, dim // 2)``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        positions = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        return torch.outer(positions, self.inv_freq)
| |
|
| |
|
class InternAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Dense (batched, fixed-length) self-attention over inputs of shape
    ``(batch, seq, hidden)``. Optionally RMS-normalizes Q and K across the
    full hidden dimension (all heads flattened together) and can dispatch to
    a fused FlashAttention implementation when available.
    """

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_flash_attn = config.use_flash_attn and has_flash_attn
        if config.use_flash_attn and not has_flash_attn:
            # Report through the shared module logger instead of a bare
            # print(), consistent with the rest of this file.
            logger.warning('Flash Attention is not available, use_flash_attn is set to False.')
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).'
            )

        self.scale = self.head_dim ** -0.5
        # Single fused projection producing Q, K and V in one matmul.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # Normalization spans the full embed_dim, i.e. all heads at once.
            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        if self.use_flash_attn:
            # NOTE(review): `FlashAttention` is neither defined nor imported in
            # this file, so enabling the flash path raises NameError here —
            # confirm the intended wrapper-class import.
            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _naive_attn(self, x):
        """Reference softmax attention. x: (B, N, C) -> (B, N, C)."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.qk_normalization:
            B_, H_, N_, D_ = q.shape
            # Flatten heads so RMSNorm runs over the full hidden size, then
            # restore the per-head layout.
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

        attn = ((q * self.scale) @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        """FlashAttention path; packs QKV as (b, s, 3, h, d)."""
        qkv = self.qkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            # Same flatten-normalize-restore pattern as _naive_attn.
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            qkv = torch.stack([q, k, v], dim=2)

        context, _ = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
        )
        outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
        return x
| |
|
| |
|
def rotate_half(x):
    """Split the last dimension in half and map (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
| |
|
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding in float32, casting back afterwards.

    ``freqs`` has shape (seq, head_dim // 2); cos/sin are tiled to head_dim
    and broadcast over a leading batch axis and the head axis of ``tensor``.
    """
    orig_dtype = tensor.dtype
    t = tensor.float()
    # (seq, d/2) -> (1, seq, 1, d): tile the half-frequencies for both halves.
    cos = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = freqs.sin().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    # Inlined rotate_half: (t1, t2) -> (-t2, t1) on the last dimension.
    half = t.shape[-1] // 2
    rotated = torch.cat((-t[..., half:], t[..., :half]), dim=-1)
    return ((t * cos) + (rotated * sin)).to(orig_dtype)
| |
|
| |
|
class InternVisionSdpaAttention(nn.Module):
    """Variable-length self-attention over packed sequences using torch SDPA.

    Input is a packed ``(total_tokens, hidden)`` tensor; ``cu_seqlens`` holds
    the cumulative per-sample boundaries and is expanded into a
    block-diagonal boolean mask so tokens attend only within their own
    sample.
    """

    def __init__(self, config: NaViLVisionConfig) -> None:
        super().__init__()

        self.config = config

        dim = config.hidden_size
        num_heads = config.num_attention_heads
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # RMSNorm over the full hidden size (all heads flattened).
            self.q_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)

        self.proj_drop = nn.Dropout(config.dropout)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        # (S, 3*dim) -> 3 tensors of shape (S, heads, head_dim).
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)

        if self.qk_normalization:
            # BUG FIX: normalize the head-flattened (S, dim) view, then restore
            # the (S, heads, head_dim) layout. The previous
            # `self.q_norm(q.flatten(1).view(q.shape))` reshaped *before*
            # normalizing, so the size-`dim` RMSNorm weight could not
            # broadcast against the per-head last dimension (RuntimeError for
            # num_heads > 1). This matches InternAttention's
            # `q_norm(...flatten(...)).view(...)` pattern.
            q = self.q_norm(q.flatten(1)).view(q.shape)
            k = self.k_norm(k.flatten(1)).view(k.shape)

        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Block-diagonal mask: True where attention is allowed.
        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
        # SDPA expects (heads, S, head_dim).
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        attn_output = self.proj_drop(attn_output)
        return attn_output
| |
|
| |
|
def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary embedding via flash-attn's fused kernel.

    Computes in float32 and casts the result back to the input dtype.
    """
    cos = freqs.cos().float()
    sin = freqs.sin().float()
    return apply_rotary_emb(tensor.float(), cos, sin).type_as(tensor)
| |
|
| |
|
class InternVisionFlashAttention2(nn.Module):
    """Variable-length self-attention over packed sequences via flash_attn.

    Mirrors InternVisionSdpaAttention but calls the fused
    ``flash_attn_varlen_func`` kernel instead of building an explicit mask.
    """

    def __init__(self, config: NaViLVisionConfig) -> None:
        super().__init__()
        self.config = config

        dim = config.hidden_size
        num_heads = config.num_attention_heads

        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # RMSNorm over the full hidden size (all heads flattened).
            self.q_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)

        self.proj_drop = nn.Dropout(config.dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """hidden_states: packed (total_tokens, hidden); cu_seqlens holds the
        cumulative sample boundaries consumed by the varlen kernel."""
        seq_length = hidden_states.shape[0]
        # (S, 3*dim) -> 3 tensors of shape (S, heads, head_dim).
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)

        if self.qk_normalization:
            # BUG FIX: normalize the head-flattened (S, dim) view, then restore
            # the per-head layout. The previous
            # `self.q_norm(q.flatten(1).view(q.shape))` reshaped *before*
            # normalizing, so the size-`dim` RMSNorm weight could not
            # broadcast against the per-head last dimension (RuntimeError for
            # num_heads > 1).
            q = self.q_norm(q.flatten(1)).view(q.shape)
            k = self.k_norm(k.flatten(1)).view(k.shape)

        q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # The varlen kernel needs the longest sample length in the batch.
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
            seq_length, -1
        )
        attn_output = self.proj(attn_output)
        attn_output = self.proj_drop(attn_output)
        return attn_output
| | |
| |
|
class InternMLP(nn.Module):
    """Transformer feed-forward block: Linear -> activation -> Linear.

    Expands from hidden_size to intermediate_size and projects back; the
    activation is looked up from the config's ``hidden_act`` string.
    """

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(hidden_states)))
| |
|