Yassine Mhirsi commited on
Commit
f28285b
·
1 Parent(s): d02227d

Add KPA model integration and update configuration for label predictions

Browse files
config.py CHANGED
@@ -14,9 +14,11 @@ PROJECT_ROOT = API_DIR.parent
14
  # Hugging Face configuration
15
  HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "")
16
  HUGGINGFACE_STANCE_MODEL_ID = os.getenv("HUGGINGFACE_STANCE_MODEL_ID")
 
17
 
18
  # Use Hugging Face model ID instead of local path
19
  STANCE_MODEL_ID = HUGGINGFACE_STANCE_MODEL_ID
 
20
 
21
  # API configuration
22
  API_TITLE = "NLP Project API"
 
14
  # Hugging Face configuration
15
  HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "")
16
  HUGGINGFACE_STANCE_MODEL_ID = os.getenv("HUGGINGFACE_STANCE_MODEL_ID")
17
+ HUGGINGFACE_LABEL_MODEL_ID = os.getenv("HUGGINGFACE_LABEL_MODEL_ID")
18
 
19
  # Use Hugging Face model ID instead of local path
20
  STANCE_MODEL_ID = HUGGINGFACE_STANCE_MODEL_ID
21
+ LABEL_MODEL_ID = HUGGINGFACE_LABEL_MODEL_ID
22
 
23
  # API configuration
24
  API_TITLE = "NLP Project API"
main.py CHANGED
@@ -19,8 +19,10 @@ from config import (
19
  API_DESCRIPTION,
20
  API_VERSION,
21
  STANCE_MODEL_ID,
 
22
  HUGGINGFACE_API_KEY,
23
  HUGGINGFACE_STANCE_MODEL_ID,
 
24
  HOST,
25
  PORT,
26
  RELOAD,
@@ -29,7 +31,7 @@ from config import (
29
  CORS_METHODS,
30
  CORS_HEADERS,
31
  )
32
- from services import stance_model_manager
33
  from routes import api_router
34
 
35
  # Configure logging
@@ -51,6 +53,14 @@ async def lifespan(app: FastAPI):
51
  logger.error(f"✗ Failed to load stance model: {str(e)}")
52
  logger.error("⚠️ Stance detection endpoints will not work!")
53
 
 
 
 
 
 
 
 
 
54
  logger.info("✓ API startup complete")
55
  logger.info("https://nlp-debater-project-fastapi-backend-models.hf.space/docs")
56
 
 
19
  API_DESCRIPTION,
20
  API_VERSION,
21
  STANCE_MODEL_ID,
22
+ LABEL_MODEL_ID,
23
  HUGGINGFACE_API_KEY,
24
  HUGGINGFACE_STANCE_MODEL_ID,
25
+ HUGGINGFACE_LABEL_MODEL_ID,
26
  HOST,
27
  PORT,
28
  RELOAD,
 
31
  CORS_METHODS,
32
  CORS_HEADERS,
33
  )
34
+ from services import stance_model_manager, kpa_model_manager
35
  from routes import api_router
36
 
37
  # Configure logging
 
53
  logger.error(f"✗ Failed to load stance model: {str(e)}")
54
  logger.error("⚠️ Stance detection endpoints will not work!")
55
 
56
+ # Load KPA (label) model
57
+ try:
58
+ logger.info(f"Loading KPA model from Hugging Face: {HUGGINGFACE_LABEL_MODEL_ID}")
59
+ kpa_model_manager.load_model(HUGGINGFACE_LABEL_MODEL_ID, HUGGINGFACE_API_KEY)
60
+ except Exception as e:
61
+ logger.error(f"✗ Failed to load KPA model: {str(e)}")
62
+ logger.error("⚠️ KPA/Label prediction endpoints will not work!")
63
+
64
  logger.info("✓ API startup complete")
65
  logger.info("https://nlp-debater-project-fastapi-backend-models.hf.space/docs")
66
 
routes/label.py CHANGED
@@ -11,7 +11,7 @@ from models import (
11
  BatchPredictionResponse
12
  )
13
 
14
- from services import label_model_manage # ton ModelManager clean
15
 
16
  router = APIRouter()
17
  logger = logging.getLogger(__name__)
@@ -23,7 +23,7 @@ async def get_model_info():
23
  Return information about the loaded KPA model.
24
  """
25
  try:
26
- model_info = label_model_manage.get_model_info()
27
 
28
  return {
29
  "model_name": model_info.get("model_name", "unknown"),
@@ -42,7 +42,7 @@ async def get_model_info():
42
  @router.post("/predict", response_model=PredictionResponse, tags=["KPA"])
43
  async def predict_kpa(request: PredictionRequest):
44
  """
45
- Predict keypointargument matching for a single pair.
46
 
47
  - **argument**: The argument text
48
  - **key_point**: The key point to evaluate
@@ -50,7 +50,7 @@ async def predict_kpa(request: PredictionRequest):
50
  Returns the predicted class (apparie / non_apparie) with probabilities.
51
  """
52
  try:
53
- result = label_model_manage.predict(
54
  argument=request.argument,
55
  key_point=request.key_point
56
  )
@@ -88,7 +88,7 @@ async def batch_predict_kpa(request: BatchPredictionRequest):
88
 
89
  for item in request.pairs:
90
  try:
91
- result = label_model_manage.predict(
92
  argument=item.argument,
93
  key_point=item.key_point
94
  )
@@ -115,7 +115,8 @@ async def batch_predict_kpa(request: BatchPredictionRequest):
115
  logger.info(f"Batch KPA prediction completed — {len(results)} items processed")
116
 
117
  return BatchPredictionResponse(
118
- predictions=results
 
119
  )
120
 
121
  except Exception as e:
 
11
  BatchPredictionResponse
12
  )
13
 
14
+ from services import kpa_model_manager
15
 
16
  router = APIRouter()
17
  logger = logging.getLogger(__name__)
 
23
  Return information about the loaded KPA model.
24
  """
25
  try:
26
+ model_info = kpa_model_manager.get_model_info()
27
 
28
  return {
29
  "model_name": model_info.get("model_name", "unknown"),
 
42
  @router.post("/predict", response_model=PredictionResponse, tags=["KPA"])
43
  async def predict_kpa(request: PredictionRequest):
44
  """
45
+ Predict keypoint-argument matching for a single pair.
46
 
47
  - **argument**: The argument text
48
  - **key_point**: The key point to evaluate
 
50
  Returns the predicted class (apparie / non_apparie) with probabilities.
51
  """
52
  try:
53
+ result = kpa_model_manager.predict(
54
  argument=request.argument,
55
  key_point=request.key_point
56
  )
 
88
 
89
  for item in request.pairs:
90
  try:
91
+ result = kpa_model_manager.predict(
92
  argument=item.argument,
93
  key_point=item.key_point
94
  )
 
115
  logger.info(f"Batch KPA prediction completed — {len(results)} items processed")
116
 
117
  return BatchPredictionResponse(
118
+ predictions=results,
119
+ total_processed=len(results)
120
  )
121
 
122
  except Exception as e:
services/__init__.py CHANGED
@@ -1,8 +1,11 @@
1
  """Services for business logic and external integrations"""
2
 
3
  from .stance_model_manager import StanceModelManager, stance_model_manager
 
4
 
5
  __all__ = [
6
  "StanceModelManager",
7
  "stance_model_manager",
 
 
8
  ]
 
1
  """Services for business logic and external integrations"""
2
 
3
  from .stance_model_manager import StanceModelManager, stance_model_manager
4
+ from .label_model_manage import KpaModelManager, kpa_model_manager
5
 
6
  __all__ = [
7
  "StanceModelManager",
8
  "stance_model_manager",
9
+ "KpaModelManager",
10
+ "kpa_model_manager",
11
  ]
services/label_model_manage.py CHANGED
@@ -17,52 +17,46 @@ class KpaModelManager:
17
  self.device = None
18
  self.model_loaded = False
19
  self.max_length = 256
 
20
 
21
- def load_model(self, model_path: str = None):
22
- """Load fine-tuned model weights and tokenizer"""
23
  if self.model_loaded:
24
  logger.info("KPA model already loaded")
25
  return
26
 
27
  try:
28
- # Detect device
 
 
29
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  logger.info(f"Using device: {self.device}")
31
-
32
- # Resolve model path
33
- if model_path is None:
34
- base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
35
- model_path = os.path.join(base_dir, "models", "modele_appariement_rapide.pth")
36
-
37
- logger.info(f"Loading KPA model weights from: {model_path}")
38
-
39
- if not os.path.exists(model_path):
40
- raise FileNotFoundError(f"Model file not found: {model_path}")
41
-
42
- # Load tokenizer + architecture
43
- model_name = "distilbert-base-uncased"
44
  logger.info("Loading tokenizer...")
45
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
46
-
47
- logger.info("Loading base architecture...")
 
 
 
 
48
  self.model = AutoModelForSequenceClassification.from_pretrained(
49
- model_name,
50
- num_labels=2
 
51
  )
52
-
53
- # Load trained weights
54
- checkpoint = torch.load(model_path, map_location=self.device)
55
-
56
- if "model_state_dict" in checkpoint:
57
- self.model.load_state_dict(checkpoint["model_state_dict"])
58
- else:
59
- self.model.load_state_dict(checkpoint)
60
-
61
  self.model.to(self.device)
62
  self.model.eval()
63
 
64
  self.model_loaded = True
65
- logger.info("✓ KPA model loaded successfully!")
66
 
67
  except Exception as e:
68
  logger.error(f"Error loading KPA model: {str(e)}")
@@ -109,12 +103,12 @@ class KpaModelManager:
109
 
110
  def get_model_info(self):
111
  return {
112
- "model_name": self.model_name,
113
  "device": str(self.device),
114
  "max_length": self.max_length,
115
  "num_labels": 2,
116
  "loaded": self.model_loaded
117
- }
118
 
119
 
120
 
 
17
  self.device = None
18
  self.model_loaded = False
19
  self.max_length = 256
20
+ self.model_id = None
21
 
22
+ def load_model(self, model_id: str, api_key: str = None):
23
+ """Load model and tokenizer from Hugging Face"""
24
  if self.model_loaded:
25
  logger.info("KPA model already loaded")
26
  return
27
 
28
  try:
29
+ logger.info(f"Loading KPA model from Hugging Face: {model_id}")
30
+
31
+ # Determine device
32
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
  logger.info(f"Using device: {self.device}")
34
+
35
+ # Store model ID
36
+ self.model_id = model_id
37
+
38
+ # Prepare token for authentication if API key is provided
39
+ token = api_key if api_key else None
40
+
41
+ # Load tokenizer and model from Hugging Face
 
 
 
 
 
42
  logger.info("Loading tokenizer...")
43
+ self.tokenizer = AutoTokenizer.from_pretrained(
44
+ model_id,
45
+ token=token,
46
+ trust_remote_code=True
47
+ )
48
+
49
+ logger.info("Loading model...")
50
  self.model = AutoModelForSequenceClassification.from_pretrained(
51
+ model_id,
52
+ token=token,
53
+ trust_remote_code=True
54
  )
 
 
 
 
 
 
 
 
 
55
  self.model.to(self.device)
56
  self.model.eval()
57
 
58
  self.model_loaded = True
59
+ logger.info("✓ KPA model loaded successfully from Hugging Face!")
60
 
61
  except Exception as e:
62
  logger.error(f"Error loading KPA model: {str(e)}")
 
103
 
104
  def get_model_info(self):
105
  return {
106
+ "model_name": self.model_id,
107
  "device": str(self.device),
108
  "max_length": self.max_length,
109
  "num_labels": 2,
110
  "loaded": self.model_loaded
111
+ }
112
 
113
 
114