S01Nour commited on
Commit
2ed0ab3
·
1 Parent(s): c924863

feat: Add `GenerateModelManager` for loading Hugging Face generation models and performing text generation.

Browse files
Files changed (1) hide show
  1. services/generate_model_manager.py +9 -4
services/generate_model_manager.py CHANGED
@@ -58,8 +58,7 @@ class GenerateModelManager:
58
 
59
  def _format_input(self, topic: str, position: str) -> str:
60
  """Format input for the model"""
61
- # Standard format for argument generation
62
- return f"generate argument: topic: {topic} stance: {position}"
63
 
64
  def generate(self, topic: str, position: str, max_length: int = 128, num_beams: int = 4) -> str:
65
  """Generate argument for a topic and position"""
@@ -83,7 +82,10 @@ class GenerateModelManager:
83
  **inputs,
84
  max_length=max_length,
85
  num_beams=num_beams,
86
- early_stopping=True
 
 
 
87
  )
88
 
89
  # Decode
@@ -114,7 +116,10 @@ class GenerateModelManager:
114
  **inputs,
115
  max_length=max_length,
116
  num_beams=num_beams,
117
- early_stopping=True
 
 
 
118
  )
119
 
120
  # Decode batch
 
58
 
59
  def _format_input(self, topic: str, position: str) -> str:
60
  """Format input for the model"""
61
+ return f"topic: {topic} stance: {position}"
 
62
 
63
  def generate(self, topic: str, position: str, max_length: int = 128, num_beams: int = 4) -> str:
64
  """Generate argument for a topic and position"""
 
82
  **inputs,
83
  max_length=max_length,
84
  num_beams=num_beams,
85
+ early_stopping=True,
86
+ no_repeat_ngram_size=3,
87
+ repetition_penalty=2.5,
88
+ length_penalty=1.0
89
  )
90
 
91
  # Decode
 
116
  **inputs,
117
  max_length=max_length,
118
  num_beams=num_beams,
119
+ early_stopping=True,
120
+ no_repeat_ngram_size=3,
121
+ repetition_penalty=2.5,
122
+ length_penalty=1.0
123
  )
124
 
125
  # Decode batch