"""Service for topic extraction from text using LangChain Groq"""
import logging
from typing import Optional, List
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_groq import ChatGroq
from pydantic import BaseModel, Field
from langsmith import traceable
from config import GROQ_API_KEY

logger = logging.getLogger(__name__)


class TopicOutput(BaseModel):
"""Pydantic schema for topic extraction output"""
topic: str = Field(..., description="A specific, detailed topic description")
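# Illustrative example of the structured output this schema enforces (added for
# clarity, not part of the original source); the value mirrors one of the
# few-shot examples used in the extraction prompt below:
#   TopicOutput(topic="government subsidies for electric vehicle adoption")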
class TopicService:
"""Service for extracting topics from text arguments"""
def __init__(self):
self.llm = None
self.model_name = "openai/gpt-oss-safeguard-120b" # another model meta-llama/llama-4-scout-17b-16e-instruct
self.initialized = False
def initialize(self, model_name: Optional[str] = None):
"""Initialize the Groq LLM with structured output"""
if self.initialized:
logger.info("Topic service already initialized")
return
if not GROQ_API_KEY:
raise ValueError("GROQ_API_KEY not found in environment variables")
if model_name:
self.model_name = model_name
try:
logger.info(f"Initializing topic extraction service with model: {self.model_name}")
llm = ChatGroq(
model=self.model_name,
api_key=GROQ_API_KEY,
temperature=0.0,
max_tokens=512,
)
# Bind structured output directly to the model
self.llm = llm.with_structured_output(TopicOutput)
self.initialized = True
logger.info("✓ Topic extraction service initialized successfully")
except Exception as e:
logger.error(f"Error initializing topic service: {str(e)}")
raise RuntimeError(f"Failed to initialize topic service: {str(e)}")
@traceable(name="extract_topic")
def extract_topic(self, text: str) -> str:
"""
Extract a topic from the given text/argument
Args:
text: The input text/argument to extract topic from
Returns:
The extracted topic string
"""
if not self.initialized:
self.initialize()
if not text or not isinstance(text, str):
raise ValueError("Text must be a non-empty string")
text = text.strip()
if len(text) == 0:
raise ValueError("Text cannot be empty")
system_message = """You are an information extraction model.
Extract a topic from the user text. The topic should be a single sentence that captures the main idea of the text in simple english.
Examples:
- Text: "Governments should subsidize electric cars to encourage adoption."
Output: topic="government subsidies for electric vehicle adoption"
- Text: "Raising the minimum wage will hurt small businesses and cost jobs."
Output: topic="raising the minimum wage and its economic impact on small businesses"
"""
try:
result = self.llm.invoke(
[
SystemMessage(content=system_message),
HumanMessage(content=text),
]
)
return result.topic
except Exception as e:
logger.error(f"Error extracting topic: {str(e)}")
raise RuntimeError(f"Topic extraction failed: {str(e)}")
def batch_extract_topics(self, texts: List[str]) -> List[str]:
"""
Extract topics from multiple texts
Args:
texts: List of input texts/arguments
Returns:
List of extracted topics
"""
if not self.initialized:
self.initialize()
if not texts or not isinstance(texts, list):
raise ValueError("Texts must be a non-empty list")
results = []
for text in texts:
try:
topic = self.extract_topic(text)
results.append(topic)
except Exception as e:
logger.error(f"Error extracting topic for text '{text[:50]}...': {str(e)}")
results.append(None) # Or raise, depending on desired behavior
return results
# Module-level singleton instance (lazily initialized on first use)
topic_service = TopicService()
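
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal demo, assuming GROQ_API_KEY is set in the environment; the sample
# arguments below are the ones used in the extraction prompt's few-shot examples.
if __name__ == "__main__":
    try:
        topic_service.initialize()
        print(topic_service.extract_topic(
            "Governments should subsidize electric cars to encourage adoption."
        ))
        print(topic_service.batch_extract_topics([
            "Governments should subsidize electric cars to encourage adoption.",
            "Raising the minimum wage will hurt small businesses and cost jobs.",
        ]))
    except (ValueError, RuntimeError) as exc:
        logger.error(f"Topic extraction demo failed: {exc}")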