vivek01 commited on
Commit
b9bce35
·
verified ·
1 Parent(s): 314ea92

Update handler.py to handle lists

Browse files
Files changed (1) hide show
  1. handler.py +41 -31
handler.py CHANGED
@@ -26,39 +26,49 @@ class EndpointHandler:
26
  self.hate_speech_encoder = joblib.load(os.path.join(model_path, 'hate_speech_encoder.pkl'))
27
 
28
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
29
- # Preprocess input: extract the text from the request payload
30
- text = data.get('inputs')
31
-
32
- # Tokenize the input text
33
- inputs = self.tokenizer(text, return_tensors='pt', max_length=256, truncation=True, padding=True)
34
- if 'token_type_ids' in inputs:
35
- del inputs['token_type_ids']
36
- inputs = {key: val.to(self.device) for key, val in inputs.items()}
37
-
38
- # Run the input through the model
39
- with torch.no_grad():
40
- outputs = self.model(**inputs)
41
- emotion_logits = outputs.get('emotion')
42
- polarity_logits = outputs.get('polarity')
43
- hate_speech_logits = outputs.get('hate_speech')
44
-
45
- # Decode predictions from logits using argmax and the label encoders
46
- emotion_preds = torch.argmax(emotion_logits, dim=1).cpu().numpy().tolist()
47
- polarity_preds = torch.argmax(polarity_logits, dim=1).cpu().numpy().tolist()
48
- hate_speech_preds = torch.argmax(hate_speech_logits, dim=1).cpu().numpy().tolist()
49
-
50
- # Inverse transform the predictions to get human-readable labels
51
- decoded_emotions = self.emotion_encoder.inverse_transform(emotion_preds).tolist()
52
- decoded_polarities = self.polarity_encoder.inverse_transform(polarity_preds).tolist()
53
- decoded_hate_speech = self.hate_speech_encoder.inverse_transform(hate_speech_preds).tolist()
54
-
55
- # Return the decoded results as a dictionary
56
- return {
57
- "emotions": decoded_emotions,
58
- "polarities": decoded_polarities,
59
- "hate_speech": decoded_hate_speech
60
  }
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def load_model(self, model_path):
63
  #Load model weights from the specified path
64
  self.load_state_dict(torch.load(model_path))
 
26
  self.hate_speech_encoder = joblib.load(os.path.join(model_path, 'hate_speech_encoder.pkl'))
27
 
28
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
29
+ # Extract the list of texts
30
+ texts = data.get('inputs', [])
31
+
32
+ # Batch processing for large inputs
33
+ batch_size = 32
34
+ results = {
35
+ "emotions": [],
36
+ "polarities": [],
37
+ "hate_speech": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  }
39
 
40
+ for i in range(0, len(texts), batch_size):
41
+ batch_texts = texts[i:i+batch_size]
42
+
43
+ # Tokenize the input text
44
+ inputs = self.tokenizer(text, return_tensors='pt', max_length=256, truncation=True, padding=True)
45
+ if 'token_type_ids' in inputs:
46
+ del inputs['token_type_ids']
47
+ inputs = {key: val.to(self.device) for key, val in inputs.items()}
48
+
49
+ # Run the input through the model
50
+ with torch.no_grad():
51
+ outputs = self.model(**inputs)
52
+ emotion_logits = outputs.get('emotion')
53
+ polarity_logits = outputs.get('polarity')
54
+ hate_speech_logits = outputs.get('hate_speech')
55
+
56
+ # Decode predictions from logits using argmax and the label encoders
57
+ emotion_preds = torch.argmax(emotion_logits, dim=1).cpu().numpy().tolist()
58
+ polarity_preds = torch.argmax(polarity_logits, dim=1).cpu().numpy().tolist()
59
+ hate_speech_preds = torch.argmax(hate_speech_logits, dim=1).cpu().numpy().tolist()
60
+
61
+ # Inverse transform the predictions to get human-readable labels
62
+ decoded_emotions = self.emotion_encoder.inverse_transform(emotion_preds).tolist()
63
+ decoded_polarities = self.polarity_encoder.inverse_transform(polarity_preds).tolist()
64
+ decoded_hate_speech = self.hate_speech_encoder.inverse_transform(hate_speech_preds).tolist()
65
+
66
+ results["emotions"].extend(decoded_emotions)
67
+ results["polarities"].extend(decoded_polarities)
68
+ results["hate_speech"].extend(decoded_hate_speech)
69
+
70
+ return results
71
+
72
  def load_model(self, model_path):
73
  #Load model weights from the specified path
74
  self.load_state_dict(torch.load(model_path))