from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum
from itertools import chain

from nltk.corpus import wordnet
from nltk.corpus.reader.wordnet import Lemma, Synset
from nltk.metrics import edit_distance
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    ModernBertModel,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput


BATCH_SIZE = 16


class ModelURI(StrEnum):
    """Hugging Face checkpoint identifiers for the ModernBERT backbones."""

    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"


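# Note: because ModelURI subclasses StrEnum, its members are plain strings
# and can be passed directly wherever a checkpoint name is expected, e.g.
# AutoConfig.from_pretrained(ModelURI.BASE).

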
@dataclass(slots=True, frozen=True)
class LexicalExample:
    """A target concept paired with one candidate sense definition (gloss)."""

    concept: str
    definition: str


@dataclass(slots=True, frozen=True)
class PaddedBatch:
    """Tokenized model inputs, padded to a common sequence length."""

    input_ids: torch.Tensor
    attention_mask: torch.Tensor


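# A minimal batching sketch, not part of the original module: it shows how
# LexicalExample pairs could be grouped and tokenized into PaddedBatch
# objects. The helper names and the use of BATCH_SIZE here are assumptions.
def pad_examples(examples: list[LexicalExample], tokenizer) -> PaddedBatch:
    # Tokenize (concept, definition) pairs together so each row becomes a
    # single [CLS] concept [SEP] definition sequence.
    tokens = tokenizer(
        [example.concept for example in examples],
        [example.definition for example in examples],
        padding=True,
        return_tensors="pt",
    )
    return PaddedBatch(input_ids=tokens.input_ids, attention_mask=tokens.attention_mask)


def iter_batches(
    examples: Iterable[LexicalExample], tokenizer
) -> Generator[PaddedBatch, None, None]:
    # Accumulate examples into fixed-size chunks, padding each chunk on its own.
    batch: list[LexicalExample] = []
    for example in examples:
        batch.append(example)
        if len(batch) == BATCH_SIZE:
            yield pad_examples(batch, tokenizer)
            batch = []
    if batch:
        yield pad_examples(batch, tokenizer)

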
class DisamBertCrossEncoder(PreTrainedModel):
    """ModernBERT cross-encoder that scores (context, gloss) pairs."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        if config.init_basemodel:
            # First construction: download the pretrained backbone weights.
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
        else:
            # Reloading a saved checkpoint: from_pretrained supplies the weights.
            self.BaseModel = ModernBertModel(config)
        config.init_basemodel = False
        self.loss = nn.BCEWithLogitsLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI) -> "DisamBertCrossEncoder":
        """Build a fresh cross-encoder on top of a pretrained base checkpoint."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        config.tokenizer_path = base_id
        return cls(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> SequenceClassifierOutput:
        base_model_output = self.BaseModel(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # Each row is a (context, gloss) pair: take the hidden vector at the
        # first [SEP] token of every row as the gloss representation.
        sep_cols = (input_ids == self.config.sep_token_id).int().argmax(dim=1)
        rows = torch.arange(input_ids.size(0), device=input_ids.device)
        gloss_vectors = token_vectors[rows, sep_cols]
        # Score each pair as the dot product of its [CLS] and gloss vectors.
        logits = torch.einsum("ij,ij->i", token_vectors[:, 0], gloss_vectors)
        return SequenceClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels.float()) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )


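# Typical round trip (a sketch; the save path is a placeholder): build from
# a base checkpoint once, fine-tune, then save. Because __init__ flips
# init_basemodel to False, reloading the saved checkpoint will not try to
# re-download the base weights.
#
#   model = DisamBertCrossEncoder.from_base(ModelURI.BASE)
#   ...fine-tune...
#   model.save_pretrained("disambert-checkpoint")
#   model = DisamBertCrossEncoder.from_pretrained("disambert-checkpoint")

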
def get_lemma(text: str, synset: Synset) -> Lemma:
    """Return the lemma of ``synset`` whose spelling is closest to ``text``."""
    return min(synset.lemmas(), key=lambda lemma: edit_distance(text, lemma.name()))
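
# For example, with the standard WordNet data,
# get_lemma("banks", wordnet.synset("depository_financial_institution.n.01"))
# returns the lemma named "bank", the candidate closest to the surface form
# by edit distance.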


class CrossEncoderTagger:
    """Ranks candidate WordNet senses for a target word in a sentence."""

    def __init__(self, url: str):
        self.model = AutoModel.from_pretrained(
            url,
            device_map="auto",
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(url)

    def __call__(self, target: str, sentence: str, candidates: list[str]) -> str:
        text = f"{target}::{sentence}"
        synsets = [wordnet.synset(candidate) for candidate in candidates]
        # Pair the context with every candidate gloss, prefixed by the lemma
        # spelling that best matches the target word.
        definitions = [
            f"{get_lemma(target, synset).name()}::{synset.definition()}"
            for synset in synsets
        ]
        sentences = [text] * len(candidates)
        tokens = self.tokenizer(
            sentences, definitions, padding=True, return_tensors="pt"
        ).to(self.model.device)
        with torch.inference_mode():
            output = self.model(
                input_ids=tokens.input_ids, attention_mask=tokens.attention_mask
            )
        # The highest-scoring (context, gloss) pair wins.
        return candidates[output.logits.argmax()]
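

# A quick end-to-end smoke test (a sketch, not part of the original module:
# the checkpoint path is a placeholder, and the candidates are ordinary
# WordNet synset identifiers).
if __name__ == "__main__":
    tagger = CrossEncoderTagger("path/to/disambert-checkpoint")
    prediction = tagger(
        "bank",
        "She sat on the bank of the river and watched the boats.",
        ["bank.n.01", "depository_financial_institution.n.01"],
    )
    print(prediction)  # expected: "bank.n.01" (the riverside sense)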