File size: 4,554 Bytes
2380f6f
 
 
 
 
 
a453c29
2380f6f
 
 
 
 
 
a453c29
 
 
 
7218dd0
2380f6f
 
a453c29
2380f6f
 
 
a453c29
2380f6f
 
 
a453c29
2380f6f
 
 
 
 
 
 
 
 
 
 
 
 
a453c29
2380f6f
 
 
 
 
 
a453c29
 
2380f6f
 
 
 
 
 
 
 
 
 
 
a453c29
2380f6f
 
 
 
 
a453c29
2380f6f
 
 
 
 
 
 
 
 
 
 
a453c29
 
 
 
 
 
 
 
 
 
2380f6f
 
 
 
 
 
 
 
 
a453c29
2380f6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""Service for topic extraction from text using LangChain Groq"""

import logging
from typing import Optional, List
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_groq import ChatGroq
from pydantic import BaseModel, Field
from langsmith import traceable

from config import GROQ_API_KEY

logger = logging.getLogger(__name__)


class TopicOutput(BaseModel):
    """Pydantic schema bound to the LLM via ``with_structured_output``.

    The model is constrained to return a single ``topic`` field; the
    ``description`` below is sent to the LLM as part of the schema, so it is
    runtime behavior, not documentation.
    """
    topic: str = Field(..., description="A specific, detailed topic description")


class TopicService:
    """Service for extracting topics from text arguments via a Groq-hosted LLM.

    The underlying ``ChatGroq`` client is created lazily by :meth:`initialize`
    (and auto-initialized on first extraction call), so importing this module
    does not require a valid ``GROQ_API_KEY``.
    """

    def __init__(self):
        # Created lazily in initialize(); holds the structured-output runnable.
        self.llm = None
        self.model_name = "openai/gpt-oss-safeguard-120b"  # another model meta-llama/llama-4-scout-17b-16e-instruct
        self.initialized = False

    def initialize(self, model_name: Optional[str] = None):
        """Initialize the Groq LLM with structured output.

        Args:
            model_name: Optional override for the default model name.
                NOTE: ignored if the service is already initialized.

        Raises:
            ValueError: If GROQ_API_KEY is not configured.
            RuntimeError: If the LLM client fails to construct.
        """
        if self.initialized:
            logger.info("Topic service already initialized")
            return

        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not found in environment variables")

        if model_name:
            self.model_name = model_name

        try:
            # Lazy %-style args avoid formatting when the level is disabled.
            logger.info("Initializing topic extraction service with model: %s", self.model_name)

            llm = ChatGroq(
                model=self.model_name,
                api_key=GROQ_API_KEY,
                temperature=0.0,  # deterministic output for extraction
                max_tokens=512,
            )

            # Bind structured output directly to the model so invoke()
            # returns a TopicOutput instance instead of a raw AI message.
            self.llm = llm.with_structured_output(TopicOutput)
            self.initialized = True

            logger.info("✓ Topic extraction service initialized successfully")

        except Exception as e:
            logger.error("Error initializing topic service: %s", e)
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to initialize topic service: {str(e)}") from e

    @traceable(name="extract_topic")
    def extract_topic(self, text: str) -> str:
        """
        Extract a topic from the given text/argument.

        Args:
            text: The input text/argument to extract topic from.

        Returns:
            The extracted topic string.

        Raises:
            ValueError: If ``text`` is not a non-empty string.
            RuntimeError: If the LLM call fails.
        """
        if not self.initialized:
            self.initialize()

        if not isinstance(text, str) or not text:
            raise ValueError("Text must be a non-empty string")

        text = text.strip()
        if not text:  # whitespace-only input
            raise ValueError("Text cannot be empty")

        system_message = """You are an information extraction model.
Extract a topic from the user text. The topic should be a single sentence that captures the main idea of the text in simple english.

Examples:
- Text: "Governments should subsidize electric cars to encourage adoption."
  Output: topic="government subsidies for electric vehicle adoption"

- Text: "Raising the minimum wage will hurt small businesses and cost jobs."
  Output: topic="raising the minimum wage and its economic impact on small businesses"
"""

        try:
            result = self.llm.invoke(
                [
                    SystemMessage(content=system_message),
                    HumanMessage(content=text),
                ]
            )
        except Exception as e:
            logger.error("Error extracting topic: %s", e)
            raise RuntimeError(f"Topic extraction failed: {str(e)}") from e

        # Structured output guarantees a TopicOutput instance.
        return result.topic

    def batch_extract_topics(self, texts: List[str]) -> List[Optional[str]]:
        """
        Extract topics from multiple texts.

        Args:
            texts: List of input texts/arguments.

        Returns:
            List of extracted topics, aligned with ``texts``; entries are
            ``None`` for texts whose extraction failed (best-effort batch).

        Raises:
            ValueError: If ``texts`` is not a non-empty list.
        """
        if not texts or not isinstance(texts, list):
            raise ValueError("Texts must be a non-empty list")

        if not self.initialized:
            self.initialize()

        results: List[Optional[str]] = []
        for text in texts:
            try:
                results.append(self.extract_topic(text))
            except Exception as e:
                # Best-effort: record the failure and continue with the batch.
                logger.error("Error extracting topic for text '%s...': %s", text[:50], e)
                results.append(None)

        return results


# Module-level singleton; importing modules share one lazily-initialized
# service instance (constructor only sets attributes — no API calls).
topic_service = TopicService()