File size: 3,447 Bytes
8dc64ab
 
ded69cc
 
 
8dc64ab
 
48ef557
 
ded69cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8dc64ab
 
 
 
 
 
 
 
 
 
 
 
 
 
ded69cc
 
795d29b
8dc64ab
 
48ef557
795d29b
8dc64ab
ded69cc
 
 
 
 
 
 
 
 
 
 
 
 
8dc64ab
ded69cc
 
795d29b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from collections import Counter
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Union

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession


# What an environment returns to the agent: either a plain string (WordleEnv
# returns emoji feedback strings) or a structured dict payload.
Observation = Union[str, dict[str, Any]]
# What the agent sends to an environment: either a plain string (WordleEnv
# converts any action with str() and expects a 4-letter word) or a structured
# dict, e.g. a user message or tool call schema.
Action = Union[str, dict[str, Any]]  # e.g., user message, tool call schema


@dataclass()
class StepResult:
    """Outcome of one environment step.

    Attributes:
        observation: What the agent observes after the step.
        reward: Scalar reward earned by the step.
        done: Whether the episode has ended.
        info: Free-form diagnostics; each instance gets its own empty dict
            unless one is supplied.
    """

    observation: Observation
    reward: float
    done: bool
    info: dict[str, Any] = field(default_factory=dict)


class WordleEnv:
    """
    Demonstration env. Not a full game; 4-letter variant for brevity.

    Observations are emoji strings: 🟩 marks a correct letter in the correct
    spot, 🟨 a letter that occurs elsewhere in the secret, ⬜ a miss.
    Actions are 4-letter words; surrounding whitespace is stripped and letters
    are compared case-insensitively.

    Reward is 1.0 on success, -0.05 for an invalid guess (which does not use up
    a turn), else 0.0. Terminal on success or after ``max_guesses`` valid
    guesses (6 by default); the episode stays terminal until ``reset()``.
    """

    _WORD_LEN = 4
    _BLANK = "⬜" * _WORD_LEN

    def __init__(self, *, secret: str = "word", max_guesses: int = 6) -> None:
        """Create a game with the given ``secret`` and guess budget.

        Raises:
            ValueError: If ``secret`` is not a 4-letter alphabetic word.
        """
        # Normalise the secret exactly like guesses are normalised in step();
        # otherwise a mixed-case secret could never be matched.
        normalized = secret.lower()
        # Raise instead of ``assert`` so the check survives ``python -O``.
        if len(normalized) != self._WORD_LEN or not normalized.isalpha():
            raise ValueError(
                f"secret must be a {self._WORD_LEN}-letter alphabetic word, got {secret!r}"
            )
        self._secret = normalized
        self._max = max_guesses
        self._n = 0  # number of valid guesses made this episode
        self._won = False
        self._obs = self._BLANK

    def reset(self) -> Observation:
        """Start a new episode and return the blank initial observation."""
        self._n = 0
        self._won = False
        self._obs = self._BLANK
        return self._obs

    def step(self, action: Action) -> StepResult:
        """Score one guess and return the resulting ``StepResult``.

        Invalid guesses (wrong length or non-letters) are penalised with
        -0.05, leave the observation unchanged and do not consume a turn.
        """
        guess = str(action).strip().lower()
        if len(guess) != self._WORD_LEN or not guess.isalpha():
            # Report the real terminal state so a finished game is not
            # re-opened by a malformed guess.
            return StepResult(self._obs, -0.05, self._is_done(), {"error": "invalid guess"})
        self._n += 1
        self._obs = self._score(guess)
        solved = guess == self._secret
        if solved:
            self._won = True
        reward = 1.0 if solved else 0.0
        return StepResult(self._obs, reward, self._is_done(), {"guesses": self._n})

    def render(self) -> str:
        """Return the most recent observation string."""
        return self._obs

    def _is_done(self) -> bool:
        """Return True once the secret was guessed or the guess budget is spent."""
        return self._won or self._n >= self._max

    def _score(self, guess: str) -> str:
        """Return emoji feedback for an already-validated ``guess``.

        Uses the standard two-pass rule so repeated letters are not
        over-reported: exact matches are marked first, then a misplaced letter
        is marked 🟨 only while unmatched copies of it remain in the secret.
        """
        pairs = list(zip(guess, self._secret))
        tiles = ["🟩" if g == s else "⬜" for g, s in pairs]
        # Secret letters not already consumed by an exact (green) match.
        unmatched = Counter(s for g, s in pairs if g != s)
        for i, (g, s) in enumerate(pairs):
            if g != s and unmatched[g] > 0:
                tiles[i] = "🟨"
                unmatched[g] -= 1
        return "".join(tiles)


@dataclass()
class SessionContext:
    """Typed state object yielded by ``session_lifespan``.

    Tools read it through ``ctx.request_context.lifespan_context``.

    Attributes:
        wordle: The Wordle environment that tool calls step through.
    """

    wordle: WordleEnv


@asynccontextmanager
async def session_lifespan(server: FastMCP) -> AsyncIterator[SessionContext]:
    """Yield the lifespan state: a new Wordle game with the secret "word".

    ``server`` is accepted to match the lifespan callback signature and is
    not used. Wrap the ``yield`` in ``try``/``finally`` if teardown work is
    ever needed.
    """
    # NOTE(review): whether this runs once per server or once per client
    # session depends on the MCP SDK and transport. Confirm before relying
    # on per-client isolation of the game.
    yield SessionContext(wordle=WordleEnv(secret="word"))


# Stateful server: session_lifespan attaches a WordleEnv to the lifespan
# context, and tools mutate that environment across calls.
# NOTE(review): confirm with the MCP SDK whether the lifespan is entered once
# per server or once per client session before relying on per-client games.
mcp = FastMCP("StatefulServer", lifespan=session_lifespan)


@mcp.tool()
def step_fn(guess: str, ctx: Context[ServerSession, SessionContext]) -> tuple[str, float, bool, dict]:
    """
    Perform a step in the Wordle environment.

    Args:
        guess (str): The guessed 4-letter word. Surrounding whitespace is
            ignored and letters are matched case-insensitively. An invalid
            guess (wrong length or non-letters) earns a -0.05 reward and does
            not use up a turn.

    Returns:
        tuple[str, float, bool, dict]: A tuple containing:
            - observation: The emoji feedback string after the step.
            - reward: The reward obtained from the step.
            - done: Whether the game is done.
            - info: Additional info (the guess count, or an error message).
    """
    # ``ctx`` is supplied by FastMCP via its ``Context`` annotation, not by the
    # caller. The game lives on the lifespan context, so successive calls
    # within that context share and advance the same game.
    env = ctx.request_context.lifespan_context.wordle
    result = env.step(guess)
    return result.observation, result.reward, result.done, result.info


# Build the Streamable HTTP application at import time so it can be handed to
# an external server (presumably an ASGI app, e.g. for uvicorn; confirm
# against the MCP SDK docs).
app = mcp.streamable_http_app()


# When executed as a script, let FastMCP run the server itself over the
# streamable-http transport.
if __name__ == "__main__":
    mcp.run(transport="streamable-http")