"""Word-sense disambiguation with a ModernBERT cross-encoder.

A fine-tuned encoder scores (sentence, WordNet gloss) pairs, and the tagger
returns the candidate synset whose gloss scores highest.
"""
from dataclasses import dataclass
from enum import StrEnum
from nltk.corpus import wordnet
from nltk.corpus.reader.wordnet import Lemma, Synset
from nltk.metrics import edit_distance
import torch
import torch.nn as nn
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
ModernBertModel,
PretrainedConfig,
PreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput

BATCH_SIZE = 16


class ModelURI(StrEnum):
    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"


@dataclass(slots=True, frozen=True)
class LexicalExample:
    concept: str
    definition: str


@dataclass(slots=True, frozen=True)
class PaddedBatch:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor


class DisamBertCrossEncoder(PreTrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        if getattr(config, "init_basemodel", False):
            # First construction: pull the pretrained encoder weights.
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
        else:
            # Reloading a checkpoint: build the bare architecture; weights are
            # filled in by PreTrainedModel.from_pretrained.
            self.BaseModel = ModernBertModel(config)
        config.init_basemodel = False
        self.loss = nn.BCEWithLogitsLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        """Initialise a fresh cross-encoder on top of a pretrained ModernBERT."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        config.tokenizer_path = base_id
        return cls(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> SequenceClassifierOutput:
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # Each row encodes a (sentence, gloss) pair; take the hidden state at
        # the first [SEP] of each row as the gloss representation. argmax over
        # a bool mask returns the first True column, assuming every row
        # contains at least one [SEP].
        sep_mask = input_ids == self.config.sep_token_id
        first_sep = sep_mask.int().argmax(dim=1)
        batch_index = torch.arange(input_ids.size(0), device=input_ids.device)
        gloss_vectors = token_vectors[batch_index, first_sep]
        # Score each pair as the dot product of its [CLS] and gloss vectors.
        logits = torch.einsum("ij,ij->i", token_vectors[:, 0], gloss_vectors)
        return SequenceClassifierOutput(
            logits=logits,
            # BCEWithLogitsLoss expects float targets.
            loss=self.loss(logits, labels.float()) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )


def get_lemma(text: str, synset: Synset) -> Lemma:
    """Return the lemma of `synset` whose spelling is closest to `text`."""
    # Edit distance picks the lemma whose surface form best matches the target.
    return min(synset.lemmas(), key=lambda lemma: edit_distance(text, lemma.name()))
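
# For example, get_lemma("dogs", wordnet.synset("dog.n.01")) returns the lemma
# "dog", the spelling closest to "dogs" by edit distance.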


class CrossEncoderTagger:
    def __init__(self, url: str):
        # trust_remote_code allows the checkpoint's custom model class
        # (e.g. DisamBertCrossEncoder) to be loaded through AutoModel.
        self.model = AutoModel.from_pretrained(
            url,
            device_map="auto",
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(url)

    def __call__(self, target: str, sentence: str, candidates: list[str]) -> str:
        text = f"{target}::{sentence}"
        synsets = [wordnet.synset(candidate) for candidate in candidates]
        # Pair the sentence with each candidate gloss, prefixed by the lemma
        # spelling closest to the target word.
        definitions = [
            f"{get_lemma(target, synset).name()}::{synset.definition()}"
            for synset in synsets
        ]
        sentences = [text] * len(candidates)
        tokens = self.tokenizer(sentences, definitions, padding=True, return_tensors="pt")
        # Move the batch to the model's device; `with self.model.device:` only
        # affects tensors created inside the block, not the tokenizer's output.
        tokens = tokens.to(self.model.device)
        with torch.inference_mode():
            output = self.model(tokens.input_ids, tokens.attention_mask)
        # The candidate whose gloss scores highest wins.
        return candidates[int(output.logits.argmax())]
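

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint id below is a placeholder for a
    # fine-tuned DisamBertCrossEncoder, not a published model.
    tagger = CrossEncoderTagger("path/to/disambert-checkpoint")
    sense = tagger(
        "bank",
        "She deposited the cheque at the bank on Monday.",
        ["bank.n.01", "bank.n.02"],  # riverbank vs. financial institution
    )
    print(sense)  # a well-trained model should pick "bank.n.02"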