Yassine Mhirsi committed on
Commit
d289997
·
1 Parent(s): 6ba8f9f

Enhance KpaModelManager to load fine-tuned weights from Hugging Face and update requirements to include huggingface_hub

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -0
  2. services/label_model_manage.py +28 -12
requirements.txt CHANGED
@@ -6,4 +6,5 @@ torch>=2.0.0
6
  transformers>=4.35.0
7
  accelerate>=0.24.0
8
  protobuf>=3.20.0
 
9
 
 
6
  transformers>=4.35.0
7
  accelerate>=0.24.0
8
  protobuf>=3.20.0
9
+ huggingface_hub>=0.19.0
10
 
services/label_model_manage.py CHANGED
@@ -3,6 +3,7 @@
3
  import os
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
6
  import logging
7
 
8
  logger = logging.getLogger(__name__)
@@ -20,7 +21,7 @@ class KpaModelManager:
20
  self.model_id = None
21
 
22
  def load_model(self, model_id: str, api_key: str = None):
23
- """Load model and tokenizer from Hugging Face"""
24
  if self.model_loaded:
25
  logger.info("KPA model already loaded")
26
  return
@@ -38,20 +39,35 @@ class KpaModelManager:
38
  # Prepare token for authentication if API key is provided
39
  token = api_key if api_key else None
40
 
41
- # Load tokenizer and model from Hugging Face
42
- logger.info("Loading tokenizer...")
43
- self.tokenizer = AutoTokenizer.from_pretrained(
44
- model_id,
45
- token=token,
46
- trust_remote_code=True
47
- )
48
 
49
- logger.info("Loading model...")
 
50
  self.model = AutoModelForSequenceClassification.from_pretrained(
51
- model_id,
52
- token=token,
53
- trust_remote_code=True
 
 
 
 
 
 
 
54
  )
 
 
 
 
 
 
 
 
 
 
55
  self.model.to(self.device)
56
  self.model.eval()
57
 
 
3
  import os
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+ from huggingface_hub import hf_hub_download
7
  import logging
8
 
9
  logger = logging.getLogger(__name__)
 
21
  self.model_id = None
22
 
23
  def load_model(self, model_id: str, api_key: str = None):
24
+ """Load model with weights from Hugging Face repository"""
25
  if self.model_loaded:
26
  logger.info("KPA model already loaded")
27
  return
 
39
  # Prepare token for authentication if API key is provided
40
  token = api_key if api_key else None
41
 
42
+ # Load base tokenizer (distilbert-base-uncased)
43
+ base_model_name = "distilbert-base-uncased"
44
+ logger.info(f"Loading tokenizer from {base_model_name}...")
45
+ self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
 
 
 
46
 
47
+ # Load base model architecture
48
+ logger.info(f"Loading base model architecture from {base_model_name}...")
49
  self.model = AutoModelForSequenceClassification.from_pretrained(
50
+ base_model_name,
51
+ num_labels=2
52
+ )
53
+
54
+ # Download and load fine-tuned weights from Hugging Face
55
+ logger.info(f"Downloading fine-tuned weights from {model_id}...")
56
+ weights_path = hf_hub_download(
57
+ repo_id=model_id,
58
+ filename="modele_appariement_rapide.pth",
59
+ token=token
60
  )
61
+
62
+ logger.info(f"Loading fine-tuned weights from {weights_path}...")
63
+ checkpoint = torch.load(weights_path, map_location=self.device)
64
+
65
+ # Load state dict
66
+ if "model_state_dict" in checkpoint:
67
+ self.model.load_state_dict(checkpoint["model_state_dict"])
68
+ else:
69
+ self.model.load_state_dict(checkpoint)
70
+
71
  self.model.to(self.device)
72
  self.model.eval()
73