from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum
from itertools import chain

from nltk.corpus import wordnet
from nltk.corpus.reader.wordnet import Lemma, Synset
from nltk.metrics import edit_distance
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    ModernBertModel,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput

BATCH_SIZE = 16


class ModelURI(StrEnum):
    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"


@dataclass(slots=True, frozen=True)
class LexicalExample:
    concept: str
    definition: str


@dataclass(slots=True, frozen=True)
class PaddedBatch:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor


class DisamBertCrossEncoder(PreTrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        if config.init_basemodel:
            # First construction: load pretrained ModernBERT weights.
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
        else:
            # Reloading from a saved checkpoint: build the architecture only.
            self.BaseModel = ModernBertModel(config)
        config.init_basemodel = False
        self.loss = nn.BCEWithLogitsLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        config.tokenizer_path = base_id
        return cls(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> SequenceClassifierOutput:
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state

        # Find the first [SEP] token in each row: the boundary between the
        # context sentence and the gloss.
        prev = -1
        rows = []
        cols = []
        for i, j in (input_ids == self.config.sep_token_id).nonzero():
            if i != prev:
                rows.append(i)
                cols.append(j)
                prev = i
        gloss_vectors = token_vectors[rows, cols]

        # Score each pair as the dot product of the [CLS] vector and the boundary vector.
        logits = torch.einsum("ij,ij->i", token_vectors[:, 0], gloss_vectors)
        return SequenceClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels.float()) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )


def get_lemma(text: str, synset: Synset) -> Lemma:
    """Return the lemma of `synset` whose name is closest to `text` by edit distance."""
    best_score = float("inf")
    best_lemma = None
    for lemma in synset.lemmas():
        score = edit_distance(text, lemma.name())
        if score < best_score:
            best_score = score
            best_lemma = lemma
    return best_lemma


class CrossEncoderTagger:
    def __init__(self, url: str):
        self.model = AutoModel.from_pretrained(url, device_map="auto", trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(url)

    def __call__(self, target: str, sentence: str, candidates: list[str]) -> str:
        text = f"{target}::{sentence}"
        synsets = [wordnet.synset(candidate) for candidate in candidates]
        definitions = [
            f"{get_lemma(target, synset).name()}::{synset.definition()}" for synset in synsets
        ]
        sentences = [text] * len(candidates)
        # Tokenize every (sentence, gloss) pair and move the batch to the model's device.
        tokens = self.tokenizer(
            sentences, definitions, padding=True, return_tensors="pt"
        ).to(self.model.device)
        with torch.no_grad():
            output = self.model(tokens.input_ids, tokens.attention_mask)
        logits = output.logits
        return candidates[logits.argmax()]
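

# Usage sketch (illustrative, not part of the module above): the example word,
# sentence, synset names, and the fine-tuned checkpoint path are assumptions
# chosen for demonstration, not values taken from the original code.
if __name__ == "__main__":
    # Build a fresh cross-encoder on top of the ModernBERT base checkpoint.
    model = DisamBertCrossEncoder.from_base(ModelURI.BASE)

    # After fine-tuning and saving/pushing the model, tagging could look like this
    # (checkpoint id is hypothetical):
    # tagger = CrossEncoderTagger("your-namespace/disambert-crossencoder")
    # sense = tagger(
    #     target="bank",
    #     sentence="She sat on the bank of the river.",
    #     candidates=["bank.n.01", "bank.n.02"],
    # )
    # print(sense)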