File size: 5,444 Bytes
4aec76b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import json
import time
from typing import List, Optional
from urllib.parse import urlsplit, urlunsplit

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_community.chat_message_histories import ChatMessageHistory

from logger import get_logger
from llm_system import config

# Module-level logger shared by the history classes below (logger name "core_history").
log = get_logger(name="core_history")


class RedisChatMessageHistory(BaseChatMessageHistory):
    """A Redis-backed chat history implementation.

    Stores messages in a Redis list keyed by `chat_history:{session_id}` as JSON
    objects with fields: `role` ("human"|"ai"), `content`, and `ts` (Unix epoch
    seconds at write time, from `time.time()`). Any message not recognized as an
    AI message is stored with role "human", so other message kinds do not
    round-trip with their original type.
    This implementation lazily imports `redis` so the module can be imported
    in environments where `redis` is not installed.
    """

    def __init__(self, session_id: str, redis_url: Optional[str] = None, ttl_seconds: int = 0):
        """Connect to Redis and bind this history to `session_id`.

        Args:
            session_id: Identifier used to build the Redis key.
            redis_url: URL passed to `redis.from_url`; when falsy, a client built
                with `redis.Redis()` defaults is used.
            ttl_seconds: Expiry in seconds, refreshed on every write. Values that
                are falsy or <= 0 mean "no expiry".

        Raises:
            ImportError: if the `redis` package is not installed.
            Exception: whatever `ping()` raises when Redis is unreachable
                (re-raised after logging so callers can fall back).
        """
        try:
            import redis
        except ImportError:
            log.error("Redis package not available. Install `redis` to use Redis history backend.")
            raise

        self._redis = redis.from_url(redis_url) if redis_url else redis.Redis()
        self.session_id = session_id
        self.key = f"chat_history:{session_id}"
        # Coerce here so a None/str value coming from config cannot break the
        # `> 0` comparison on every write later on. <= 0 means no expiry.
        self.ttl_seconds = int(ttl_seconds) if ttl_seconds else 0
        # Ping once to fail fast if Redis is unreachable.
        try:
            self._redis.ping()
        except Exception as e:
            # Never log the raw URL: Redis URLs may embed a password.
            log.error(f"Unable to connect to Redis at {self._redact_url(redis_url)}: {e}")
            raise

    @staticmethod
    def _redact_url(url: Optional[str]) -> str:
        """Return a loggable form of `url` with credentials masked or removed."""
        if not url:
            return "default redis.Redis() settings"
        try:
            parts = urlsplit(url)
        except ValueError:
            return "<unparseable redis url>"
        netloc = parts.netloc
        if parts.password is not None:
            host = netloc.rpartition("@")[2]
            netloc = f"{parts.username or ''}:***@{host}"
        # Query strings can also carry credentials (e.g. `?password=` for unix
        # socket URLs), so drop query and fragment from the logged form.
        return urlunsplit(parts._replace(netloc=netloc, query="", fragment=""))

    @staticmethod
    def _serialize(message: BaseMessage) -> str:
        """Encode `message` as the JSON record stored in the Redis list."""
        # `type == "ai"` covers plain AI messages. The class-name check
        # presumably exists for AI message variants (e.g. chunk classes) whose
        # `type` is not exactly "ai".
        is_ai = (
            getattr(message, "type", None) == "ai"
            or message.__class__.__name__.lower().startswith("aimessage")
        )
        record = {
            "role": "ai" if is_ai else "human",
            "content": getattr(message, "content", str(message)),
            "ts": time.time(),
        }
        return json.dumps(record)

    @property
    def messages(self) -> List[BaseMessage]:
        """Return this session's messages in insertion order.

        Entries whose role is "ai" become `AIMessage`. Everything else,
        including a missing role, becomes `HumanMessage`. Entries that cannot be
        decoded into a message are skipped with a warning instead of failing the
        whole read.
        """
        msgs: List[BaseMessage] = []
        for item in self._redis.lrange(self.key, 0, -1):
            try:
                obj = json.loads(item)
                message_cls = AIMessage if obj.get("role") == "ai" else HumanMessage
                msgs.append(message_cls(content=obj.get("content", "")))
            except (ValueError, TypeError, AttributeError) as e:
                # ValueError: bad JSON / undecodable bytes / rejected content;
                # AttributeError: JSON that is not an object.
                log.warning(f"Skipping malformed history entry in {self.key}: {e}")
        return msgs

    def add_message(self, message: BaseMessage) -> None:
        """Append a single message to the Redis list (see `add_messages`)."""
        self.add_messages([message])

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Append `messages` in order and refresh the TTL in one pipeline.

        Materializes the input first, so generators are accepted and an empty
        input results in no Redis calls.
        """
        payloads = [self._serialize(message) for message in messages]
        if not payloads:
            return
        pipe = self._redis.pipeline()
        pipe.rpush(self.key, *payloads)
        if self.ttl_seconds > 0:
            pipe.expire(self.key, self.ttl_seconds)
        pipe.execute()

    def clear(self) -> None:
        """Delete this session's Redis list."""
        self._redis.delete(self.key)


class HistoryStore:
    """A class to manage chat message histories for different sessions/users.

    This store can be backed by in-memory `ChatMessageHistory` (default) or
    by `RedisChatMessageHistory` when `llm_system.config.HISTORY_BACKEND` is
    'redis' (compared case-insensitively, surrounding whitespace ignored).
    History objects are cached per session in `self.histories`.
    """

    def __init__(self):
        # session_id -> BaseChatMessageHistory instance
        self.histories = {}
        # Normalize so values like "Redis" or " redis " still select the Redis backend.
        self.backend = str(getattr(config, "HISTORY_BACKEND", "memory") or "memory").strip().lower()
        self.redis_url = getattr(config, "REDIS_URL", None)
        log.info(f"Initialized HistoryStore (backend={self.backend}).")

    @staticmethod
    def _redis_ttl() -> int:
        """Read `config.REDIS_HISTORY_TTL_SECONDS` as an int (0 = no expiry).

        Missing, None, or non-integer values become 0 (with a warning for the
        invalid case), so a string/None from config cannot break the TTL
        comparison inside the Redis history later.
        """
        raw = getattr(config, "REDIS_HISTORY_TTL_SECONDS", 0)
        try:
            return int(raw) if raw else 0
        except (TypeError, ValueError):
            log.warning(f"Invalid REDIS_HISTORY_TTL_SECONDS={raw!r}; history will not expire.")
            return 0

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """Retrieve or create the chat history for `session_id`.

        Returns a `BaseChatMessageHistory` implementation appropriate for the
        configured backend. If the Redis history cannot be created, an
        in-memory `ChatMessageHistory` is used instead.
        """
        if session_id in self.histories:
            log.info(f"Retrieved existing history for session: `{session_id}`")
            return self.histories[session_id]

        if self.backend == "redis":
            ttl = self._redis_ttl()
            try:
                hist = RedisChatMessageHistory(session_id=session_id, redis_url=self.redis_url, ttl_seconds=ttl)
            except Exception:
                # NOTE(review): the fallback is cached below, so this session stays
                # in-memory for the store's lifetime even if Redis recovers.
                log.exception("Failed to initialize RedisChatMessageHistory, falling back to in-memory.")
                hist = ChatMessageHistory()
        else:
            hist = ChatMessageHistory()

        self.histories[session_id] = hist
        # Report the class actually used: after a fallback it differs from the configured backend.
        log.info(
            f"Created history for session: `{session_id}` "
            f"(backend={self.backend}, impl={type(hist).__name__})"
        )
        return hist

    def clear_session_history(self, session_id: str) -> bool:
        """Clear and forget the cached history for `session_id`.

        Returns:
            True if the session was cached by this store. Its backing history is
            cleared best-effort (failures are logged) and it is removed from the
            cache. False if this store had no history for the session.
        """
        hist = self.histories.pop(session_id, None)
        if hist is None:
            # NOTE(review): with the Redis backend, data persisted for a session this
            # store instance never fetched (e.g. after a restart) is not deleted here.
            log.warning(f"No history found for session: `{session_id}` to clear.")
            return False
        try:
            hist.clear()
        except Exception:
            # Best-effort: the session is still dropped from the cache, but the
            # failure (e.g. Redis unavailable) is surfaced instead of swallowed.
            log.exception(f"Failed to clear backing history for session: `{session_id}`")
        log.info(f"Cleared history for session: `{session_id}`")
        return True