"""Improved sublinear-attention anchor selection for AGILLM-4.

V2 used by the live AGILLM-4 DBlock line:
- fixes gathered ALiBi distance so past causal keys receive distance penalty
- suppresses local/anchor duplicate candidates before softmax
- uses hybrid full-span + recent-tail anchors under the same max-anchor budget
- exposes first-token attention sinks as `--sublinear_sinks`
- includes optional pooled landmark K/V summaries behind
  `--sublinear_pooled_landmarks`

This file provides the anchor-position helpers (`select_hybrid_anchors`,
`local_anchor_valid`) plus a small old-vs-new comparison demo (`main`).
NOTE(review): the ALiBi fix, the CLI flags and the pooled landmarks are not
implemented in this file; presumably they live in the DBlock/model code that
calls these helpers. Confirm.
"""

import torch


def select_hybrid_anchors(k_len, stride, max_anchors, sinks=4, recent_anchors=-1, device='cpu'):
    """Full-span + recent-tail landmark positions, plus attention sinks.

    Candidate anchors are the key positions ``stride - 1, 2*stride - 1, ...``
    below ``k_len``. If there are at most ``max_anchors`` candidates, all of
    them are kept. Otherwise the budget is split into:

    * a recent-tail part: the last ``recent_budget`` candidates, and
    * a full-span part: ``max_anchors - recent_budget`` candidates taken at
      evenly spaced indices across *all* candidates (first and last included).

    NOTE: the full-span picks are drawn from the whole candidate range, so some
    of them can coincide with recent-tail picks. After de-duplication the
    result may therefore contain fewer than ``max_anchors`` strided anchors.
    ``max_anchors`` is an upper bound, not an exact count.

    Args:
        k_len: Number of key positions. Every returned position is in
            ``[0, k_len)``.
        stride: Spacing between candidate anchors. ``stride <= 0`` yields no
            strided anchors.
        max_anchors: Upper bound on strided anchors (sinks are not counted).
            ``max_anchors <= 0`` yields no strided anchors.
        sinks: Number of leading positions ``0 .. sinks-1`` that are always
            added as attention sinks, on top of ``max_anchors``.
        recent_anchors: Size of the recent-tail part, clamped to
            ``[0, max_anchors]``. A negative value means ``max_anchors // 2``.
            It is only used when the candidates exceed ``max_anchors``.
        device: Device for the returned tensor.

    Returns:
        A 1-D ``torch.long`` tensor of sorted, unique key positions on
        ``device`` (possibly empty).
    """
    device = torch.device(device)
    start = stride - 1
    if stride <= 0 or max_anchors <= 0 or start >= k_len:
        anchors = torch.empty(0, device=device, dtype=torch.long)
    else:
        all_anchors = torch.arange(start, k_len, stride, device=device, dtype=torch.long)
        if all_anchors.numel() <= max_anchors:
            anchors = all_anchors
        else:
            if recent_anchors < 0:
                recent_anchors = max_anchors // 2
            recent_budget = min(max(0, int(recent_anchors)), max_anchors)
            span_budget = max(0, max_anchors - recent_budget)
            parts = []
            if span_budget > 0:
                # Evenly spaced candidate indices over the full span. unique()
                # is a safety net against rounding collisions.
                sel = torch.linspace(
                    0, all_anchors.numel() - 1, span_budget, device=device
                ).round().long().unique()
                parts.append(all_anchors[sel])
            if recent_budget > 0:
                parts.append(all_anchors[-recent_budget:])
            # unique() merges span/tail overlaps and sorts the result.
            anchors = (
                torch.cat(parts).unique()
                if parts
                else torch.empty(0, device=device, dtype=torch.long)
            )
    if sinks > 0 and k_len > 0:
        sink_idx = torch.arange(min(int(sinks), k_len), device=device, dtype=torch.long)
        anchors = torch.cat([sink_idx, anchors]).unique() if anchors.numel() else sink_idx
    return anchors


def local_anchor_valid(q_pos, anchors, window, k_len):
    """False where an anchor is already inside that query's local window.

    Args:
        q_pos: Tensor of query positions (flattened to ``Q`` entries).
        anchors: Tensor of anchor key positions (flattened to ``A`` entries).
        window: Local-window half-width. A query at ``q`` is treated as
            already covering keys ``[q - window, q + window]``, clamped to
            ``[0, k_len - 1]``.
        k_len: Number of key positions.

    Returns:
        A bool tensor of shape ``(Q, A)``. An entry is ``True`` when the anchor
        lies outside the query's local window, so it is not a duplicate and
        should be kept. It is ``False`` when the anchor would duplicate a
        local key.

    NOTE(review): the window is checked symmetrically around ``q``. For causal
    attention, anchors after ``q`` are presumably removed by a separate causal
    mask in the caller. Confirm.
    """
    anchor_idx = anchors.view(1, -1).expand(q_pos.numel(), -1)
    local_lo = (q_pos - window).clamp_min(0).view(-1, 1)
    local_hi = (q_pos + window).clamp_max(max(0, k_len - 1)).view(-1, 1)
    return (anchor_idx < local_lo) | (anchor_idx > local_hi)


def main(k_len=32768, window=128, stride=128, max_anchors=128, sinks=4, recent=64):
    """Print an old-vs-new anchor coverage comparison for one long sequence.

    "Old" is built here as the final query's local window plus the last
    ``max_anchors`` strided positions. "New" is ``select_hybrid_anchors``.
    """
    old_all = list(range(stride - 1, k_len, stride))
    old = sorted(set(range(k_len - window - 1, k_len)) | set(old_all[-max_anchors:]))
    new = select_hybrid_anchors(k_len, stride, max_anchors, sinks, recent).tolist()
    print(f'N={k_len} stride={stride} maxA={max_anchors} sinks={sinks} recent={recent}')
    print(f'OLD anchor coverage: {min(old)}..{max(old)} first_half={sum(x < k_len//2 for x in old)}')
    print(f'NEW anchor coverage: {min(new)}..{max(new)} first_half={sum(x < k_len//2 for x in new)} recent_tail={sum(x >= k_len-8192 for x in new)}')
    q = torch.tensor([k_len - 1])
    # local_anchor_valid() is True for anchors that SURVIVE suppression, so
    # its sum is the kept count; the suppressed count is the remainder.
    kept = int(local_anchor_valid(q, torch.tensor(new), window, k_len).sum())
    print(f'anchors kept for final query after duplicate suppression: {kept}/{len(new)} (suppressed {len(new) - kept})')


if __name__ == '__main__':
    main()