| """ |
| Evaluation script for trained model with comprehensive analysis |
| """ |
| import argparse |
| import sys |
| import os |
| import numpy as np |
| import pandas as pd |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer |
|
|
| |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
|
|
| from src import ( |
| load_config, |
| compute_metrics_factory, |
| plot_confusion_matrix, |
| print_classification_report |
| ) |
| from src.data_loader import prepare_datasets_for_training |
|
|
|
|
def _decode_example(test_dataset, index, tokenizer):
    """Decode the input text of one dataset example, or return None if that is not possible."""
    if test_dataset is None or tokenizer is None:
        return None
    try:
        input_ids = test_dataset[index]['input_ids']
    except (KeyError, IndexError, TypeError):
        # Example is missing or was not tokenized into ``input_ids``; skip the text.
        return None
    return tokenizer.decode(input_ids, skip_special_tokens=True)


def _print_error_examples(error_df: pd.DataFrame, top_n: int) -> None:
    """Print up to ``top_n`` decoded texts per error type, most frequent error type first."""
    with_text = error_df[error_df['text'].notna()]
    if with_text.empty or top_n <= 0:
        return
    print(f"\nExample errors (up to {top_n} per type):")
    for error_type in with_text['error_type'].value_counts().index:
        examples = with_text[with_text['error_type'] == error_type].head(top_n)
        print(f"\n  {error_type}:")
        for idx, text in zip(examples['index'], examples['text']):
            print(f"    [{idx}] {text}")


def analyze_errors(
    test_dataset,
    predictions: np.ndarray,
    labels: np.ndarray,
    id2label: dict,
    tokenizer,
    top_n: int = 10
) -> pd.DataFrame:
    """
    Analyze misclassified examples.

    Every position where the prediction differs from the true label becomes
    one row of the returned DataFrame. When both ``test_dataset`` and
    ``tokenizer`` are provided, each misclassified example's ``input_ids`` are
    decoded so the row carries the input text. Up to ``top_n`` examples per
    error type are then printed.

    Args:
        test_dataset: Indexable test dataset whose items are mappings with an
            ``input_ids`` entry; may be None to skip text decoding.
        predictions: Predicted label ids, aligned element-wise with ``labels``.
        labels: True label ids.
        id2label: Mapping from label id to label name.
        tokenizer: Tokenizer providing ``decode``; may be None to skip decoding.
        top_n: Maximum number of example texts printed per error type.

    Returns:
        DataFrame with columns ``index``, ``true_label``, ``predicted_label``,
        ``error_type`` and ``text``. ``text`` is None where decoding was not
        possible. The frame is empty, but keeps these columns, when there are
        no errors.

    Raises:
        ValueError: If ``predictions`` and ``labels`` differ in shape.
    """
    preds = np.asarray(predictions)
    trues = np.asarray(labels)
    if preds.shape != trues.shape:
        # zip() would silently truncate and hide a misaligned evaluation.
        raise ValueError(
            f"predictions and labels must have the same shape, "
            f"got {preds.shape} and {trues.shape}"
        )

    columns = ['index', 'true_label', 'predicted_label', 'error_type', 'text']
    errors = []
    for position in np.flatnonzero(preds != trues):
        i = int(position)
        # int() normalizes numpy scalars before the dict lookup.
        true_name = id2label[int(trues[i])]
        pred_name = id2label[int(preds[i])]
        errors.append({
            'index': i,
            'true_label': true_name,
            'predicted_label': pred_name,
            'error_type': f"{true_name} -> {pred_name}",
            'text': _decode_example(test_dataset, i, tokenizer),
        })

    error_df = pd.DataFrame(errors, columns=columns)
    if len(error_df) > 0:
        print("\nError Analysis:")
        print(f"Total errors: {len(error_df)}")
        print("\nError type distribution:")
        print(error_df['error_type'].value_counts())
        _print_error_examples(error_df, top_n)

    return error_df
|
|
|
|
def _format_metric(value) -> str:
    """Format a metric value with 4 decimals, or return 'N/A' when it is missing."""
    return 'N/A' if value is None else f"{value:.4f}"


def _print_overall_metrics(metrics: dict) -> None:
    """Print the aggregate ``test_*`` metrics that are present in ``metrics``."""
    print("\nOverall Metrics:")
    overall_metrics = ['accuracy', 'f1_weighted', 'f1_macro', 'precision_weighted', 'recall_weighted']
    for metric in overall_metrics:
        key = f'test_{metric}'
        if key in metrics:
            print(f" {metric.replace('_', ' ').title()}: {metrics[key]:.4f}")


def _print_per_class_metrics(metrics: dict, label_names: list) -> None:
    """Print precision, recall, F1 and support for every label that has per-class metrics."""
    print("\nPer-Class Metrics:")
    for label_name in label_names:
        precision_key = f'test_precision_{label_name}'
        # A class is reported only when its precision was emitted; the other
        # per-class keys are looked up defensively rather than assumed present.
        if precision_key not in metrics:
            continue
        print(f"\n {label_name.upper()}:")
        print(f" Precision: {metrics[precision_key]:.4f}")
        print(f" Recall: {_format_metric(metrics.get(f'test_recall_{label_name}'))}")
        print(f" F1-Score: {_format_metric(metrics.get(f'test_f1_{label_name}'))}")
        print(f" Support: {metrics.get(f'test_support_{label_name}', 'N/A')}")


def _save_confusion_matrices(labels, predictions, label_names, output_dir: str) -> None:
    """Save the raw-count and the normalized confusion matrix plots into ``output_dir``."""
    for filename, normalize in (
        ("confusion_matrix.png", False),
        ("confusion_matrix_normalized.png", True),
    ):
        plot_confusion_matrix(
            labels,
            predictions,
            label_names,
            save_path=os.path.join(output_dir, filename),
            normalize=normalize
        )


def evaluate_model(
    model_path: str,
    config_path: str = "config.yaml",
    save_plots: bool = True
):
    """
    Evaluate trained model on test set with comprehensive analysis.

    Loads the test split via ``prepare_datasets_for_training`` and runs the
    model through a ``Trainer``. Prints overall and per-class metrics and a
    classification report. Optionally saves confusion-matrix plots and an
    error-analysis CSV into ``config['training']['output_dir']`` (default
    ``./results``).

    Args:
        model_path: Path to the trained model; its tokenizer is loaded from
            the same path.
        config_path: Path to configuration file.
        save_plots: Whether to save the confusion-matrix plots and the
            error-analysis CSV.

    Returns:
        The metrics dict returned by ``Trainer.predict``, with keys prefixed
        by ``test_``.

    Raises:
        ValueError: If the prediction output contains no labels.
    """
    print("=" * 60)
    print("Model Evaluation")
    print("=" * 60)

    config = load_config(config_path)
    output_dir = config['training'].get('output_dir', './results')
    os.makedirs(output_dir, exist_ok=True)

    print("\n[1/5] Loading datasets...")
    tokenized_datasets, _, id2label, _ = prepare_datasets_for_training(config_path)
    test_dataset = tokenized_datasets['test']
    print(f"✓ Test samples: {len(test_dataset)}")

    print("\n[2/5] Loading trained model...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    print(f"✓ Model loaded from {model_path}")

    print("\n[3/5] Running evaluation...")
    compute_metrics_fn = compute_metrics_factory(id2label)
    # NOTE(review): newer transformers releases deprecate Trainer's ``tokenizer=``
    # argument in favour of ``processing_class=``; confirm against the pinned version.
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn
    )

    predictions_output = trainer.predict(test_dataset)
    logits = predictions_output.predictions
    # When the model returns extra outputs, the Trainer hands back a tuple;
    # the classification logits are expected to be its first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    labels = predictions_output.label_ids
    if labels is None:
        raise ValueError("Test dataset has no labels; cannot compute evaluation metrics.")

    print("\n[4/5] Computing detailed metrics...")
    print("\n" + "=" * 60)
    print("Test Set Results")
    print("=" * 60)

    metrics = predictions_output.metrics
    _print_overall_metrics(metrics)

    # Assumes id2label keys are the contiguous ints 0..num_labels-1.
    label_names = [id2label[i] for i in range(len(id2label))]
    _print_per_class_metrics(metrics, label_names)

    print("\n" + "=" * 60)
    print_classification_report(labels, predictions, label_names)

    print("\n[5/5] Generating visualizations...")
    if save_plots:
        _save_confusion_matrices(labels, predictions, label_names, output_dir)

    error_df = analyze_errors(test_dataset, predictions, labels, id2label, tokenizer)
    if len(error_df) > 0 and save_plots:
        error_path = os.path.join(output_dir, "error_analysis.csv")
        error_df.to_csv(error_path, index=False)
        print(f"✓ Error analysis saved to {error_path}")

    print("\n" + "=" * 60)
    print("Evaluation Complete! 🎉")
    print("=" * 60)
    print(f"\nResults saved to: {output_dir}")
    return metrics
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, then run the evaluation.
    cli = argparse.ArgumentParser(description="Evaluate trained model")
    cli.add_argument("--model-path", type=str, default="./results/final_model",
                     help="Path to the trained model")
    cli.add_argument("--config", type=str, default="config.yaml",
                     help="Path to configuration file")
    cli.add_argument("--no-plots", action="store_true",
                     help="Skip generating visualization plots")
    cli_args = cli.parse_args()

    # Plots are on by default; passing --no-plots turns them off.
    evaluate_model(
        cli_args.model_path,
        cli_args.config,
        save_plots=not cli_args.no_plots,
    )
|
|