Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import joblib | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import pytorch_lightning as pl | |
| from sklearn.metrics.pairwise import cosine_distances | |
| from sentence_transformers import SentenceTransformer | |
class IntentClassifier(pl.LightningModule):
    """Single-hidden-layer MLP mapping an embedding vector to intent logits.

    Layer stack: Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) -> Linear.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output logits (one per intent class).
        lr: Learning rate handed to Adam.
        weight_decay: Weight decay handed to Adam.
    """

    def __init__(self, input_dim=384, hidden_dim=256, output_dim=150, lr=1e-3, weight_decay=1e-4):
        super().__init__()
        # Makes the constructor arguments available as self.hparams.
        self.save_hyperparameters()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the raw (un-normalised) class logits for a batch ``x``."""
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        logits = self(x)
        # NOTE(review): unlike validation_step, labels equal to -1 are not
        # masked here -- presumably training batches never contain them; confirm.
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss and accuracy over the samples whose label is not -1."""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        # Samples labelled -1 are excluded from both metrics.
        mask = y != -1
        if mask.sum() > 0:
            val_loss = F.cross_entropy(logits[mask], y[mask])
            val_acc = (preds[mask] == y[mask]).float().mean()
        else:
            # No usable sample in this batch: log zeros as placeholders.
            # NOTE(review): these zeros still take part in whatever
            # aggregation the logger applies to val_loss / val_acc.
            val_loss = torch.tensor(0.0, device=self.device)
            val_acc = torch.tensor(0.0, device=self.device)
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", val_acc, prog_bar=True)

    def configure_optimizers(self):
        """Build the Adam optimizer from the saved hyperparameters."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            # Previously hard-coded to 1e-4, which silently ignored the
            # ``weight_decay`` constructor argument.
            weight_decay=self.hparams.weight_decay,
        )
class IntentClassifierWithOOS:
    """Inference wrapper: intent classification plus out-of-scope (OOS) detection.

    A sentence is embedded, classified, and then a small feature vector derived
    from the classifier output is scored by a separate OOS detector.

    Args:
        embedder: Object exposing ``encode(sentence)`` that returns an
            array-like embedding.
        classifier: ``torch`` module mapping a ``(1, dim)`` float tensor to
            ``(1, n_classes)`` logits; it is put in eval mode and moved to
            ``device``.
        oos_detector: Object exposing ``predict_proba(features)``; column 1 of
            its output is used as the OOS score.
        label_encoder: Object exposing ``inverse_transform([class_id])``.
        centroids_dict: Mapping ``class_id -> centroid`` vector.
        oos_threshold: Scores greater than or equal to this are reported as OOS.
        device: Device on which the classifier runs.
    """

    def __init__(self, embedder, classifier, oos_detector, label_encoder, centroids_dict, oos_threshold=0.5, device="cpu"):
        self.embedder = embedder
        self.classifier = classifier.eval().to(device)
        self.oos_detector = oos_detector
        self.label_encoder = label_encoder
        self.centroids_dict = centroids_dict
        self.threshold = oos_threshold
        self.device = device

    def _compute_features(self, embedding, logits, predicted_class):
        """Return the OOS feature vector ``[entropy, msp, centroid_distance]``.

        Args:
            embedding: 1-D NumPy embedding of the sentence.
            logits: 1-D tensor of class logits.
            predicted_class: Integer id of the arg-max class.

        Returns:
            ``np.ndarray`` of length 3. The distance entry is ``nan`` when
            ``predicted_class`` has no entry in ``centroids_dict``.
        """
        probs = F.softmax(logits, dim=0).cpu().numpy()
        # Small epsilon keeps log() finite for zero probabilities.
        entropy = -np.sum(probs * np.log(probs + 1e-10))
        # Maximum softmax probability.
        msp = np.max(probs)
        # Euclidean distance to the centroid of the predicted class.
        centroid = self.centroids_dict.get(predicted_class)
        dist = np.linalg.norm(embedding - centroid) if centroid is not None else np.nan
        # Only these three features are consumed downstream; the energy,
        # logit-gap, cosine-distance and norm values that used to be computed
        # here were never returned and have been dropped.
        return np.array([entropy, msp, dist])

    def predict(self, sentence):
        """Classify ``sentence`` and flag it as out-of-scope when appropriate.

        Returns:
            dict with keys ``intent`` (decoded label, or ``"oos"``),
            ``is_oos`` (bool), ``confidence`` (softmax probability of the
            predicted class, ``None`` when OOS) and ``oos_score`` (float).
        """
        # 1. Embedding
        embedding = np.array(self.embedder.encode(sentence))
        embedding_tensor = torch.tensor(embedding, dtype=torch.float32).unsqueeze(0).to(self.device)

        # 2. Intent prediction (MLP)
        with torch.no_grad():
            logits = self.classifier(embedding_tensor).squeeze(0)
        probs = F.softmax(logits, dim=0)
        predicted_class = torch.argmax(probs).item()
        confidence = probs[predicted_class].item()

        # 3. Feature extraction (reshaped to a single-row 2-D array)
        features = self._compute_features(embedding, logits, predicted_class).reshape(1, -1)

        # 4. OOS detection
        # Cast to a built-in float so the result is a plain Python value,
        # consistent with the bool() cast applied to is_oos below.
        oos_score = float(self.oos_detector.predict_proba(features)[0, 1])
        is_oos = oos_score >= self.threshold

        # 5. Output
        return {
            "intent": "oos" if is_oos else self.label_encoder.inverse_transform([predicted_class])[0],
            "is_oos": bool(is_oos),
            "confidence": None if is_oos else confidence,
            "oos_score": oos_score,
        }
# Restore every saved inference artifact from the current working directory.
best_model = IntentClassifier.load_from_checkpoint(
    "intent_classifier.ckpt", map_location=torch.device("cpu")
)
# The four joblib artifacts are read in the order listed.
oos_detector, label_encoder, class_centroids, best_threshold = (
    joblib.load(artifact_path)
    for artifact_path in (
        "oos_detector.pkl",
        "label_encoder.pkl",
        "class_centroids.pkl",
        "oos_threshold.pkl",
    )
)
print("Model charging")

# Re-instantiate the sentence-embedding model.
embedder = SentenceTransformer("intfloat/e5-small-v2")

# Assemble the complete inference pipeline (classifier + OOS detector) on CPU.
model = IntentClassifierWithOOS(
    embedder,
    best_model,
    oos_detector,
    label_encoder,
    class_centroids,
    oos_threshold=best_threshold,
    device="cpu",
)