main
  1"""Lightweight connection handling for MCP servers."""
  2
  3from abc import ABC, abstractmethod
  4from contextlib import AsyncExitStack
  5from typing import Any
  6
  7from mcp import ClientSession, StdioServerParameters
  8from mcp.client.sse import sse_client
  9from mcp.client.stdio import stdio_client
 10from mcp.client.streamable_http import streamablehttp_client
 11
 12
 13class MCPConnection(ABC):
 14    """Base class for MCP server connections."""
 15
 16    def __init__(self):
 17        self.session = None
 18        self._stack = None
 19
 20    @abstractmethod
 21    def _create_context(self):
 22        """Create the connection context based on connection type."""
 23
 24    async def __aenter__(self):
 25        """Initialize MCP server connection."""
 26        self._stack = AsyncExitStack()
 27        await self._stack.__aenter__()
 28
 29        try:
 30            ctx = self._create_context()
 31            result = await self._stack.enter_async_context(ctx)
 32
 33            if len(result) == 2:
 34                read, write = result
 35            elif len(result) == 3:
 36                read, write, _ = result
 37            else:
 38                raise ValueError(f"Unexpected context result: {result}")
 39
 40            session_ctx = ClientSession(read, write)
 41            self.session = await self._stack.enter_async_context(session_ctx)
 42            await self.session.initialize()
 43            return self
 44        except BaseException:
 45            await self._stack.__aexit__(None, None, None)
 46            raise
 47
 48    async def __aexit__(self, exc_type, exc_val, exc_tb):
 49        """Clean up MCP server connection resources."""
 50        if self._stack:
 51            await self._stack.__aexit__(exc_type, exc_val, exc_tb)
 52        self.session = None
 53        self._stack = None
 54
 55    async def list_tools(self) -> list[dict[str, Any]]:
 56        """Retrieve available tools from the MCP server."""
 57        response = await self.session.list_tools()
 58        return [
 59            {
 60                "name": tool.name,
 61                "description": tool.description,
 62                "input_schema": tool.inputSchema,
 63            }
 64            for tool in response.tools
 65        ]
 66
 67    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
 68        """Call a tool on the MCP server with provided arguments."""
 69        result = await self.session.call_tool(tool_name, arguments=arguments)
 70        return result.content
 71
 72
 73class MCPConnectionStdio(MCPConnection):
 74    """MCP connection using standard input/output."""
 75
 76    def __init__(self, command: str, args: list[str] = None, env: dict[str, str] = None):
 77        super().__init__()
 78        self.command = command
 79        self.args = args or []
 80        self.env = env
 81
 82    def _create_context(self):
 83        return stdio_client(
 84            StdioServerParameters(command=self.command, args=self.args, env=self.env)
 85        )
 86
 87
 88class MCPConnectionSSE(MCPConnection):
 89    """MCP connection using Server-Sent Events."""
 90
 91    def __init__(self, url: str, headers: dict[str, str] = None):
 92        super().__init__()
 93        self.url = url
 94        self.headers = headers or {}
 95
 96    def _create_context(self):
 97        return sse_client(url=self.url, headers=self.headers)
 98
 99
class MCPConnectionHTTP(MCPConnection):
    """MCP connection over Streamable HTTP.

    Stores the server URL and request headers and passes them to
    ``streamablehttp_client`` when the connection is opened.
    """

    def __init__(self, url: str, headers: dict[str, str] = None):
        super().__init__()
        self.url = url
        # Missing or empty headers are normalized to a fresh empty dict.
        self.headers = headers if headers else {}

    def _create_context(self):
        endpoint, extra_headers = self.url, self.headers
        return streamablehttp_client(url=endpoint, headers=extra_headers)
110
111
def create_connection(
    transport: str,
    command: str = None,
    args: list[str] = None,
    env: dict[str, str] = None,
    url: str = None,
    headers: dict[str, str] = None,
) -> MCPConnection:
    """Build the MCP connection object matching the requested transport.

    The transport name is matched case-insensitively. "http",
    "streamable_http" and "streamable-http" all select the Streamable HTTP
    connection.

    Args:
        transport: Connection type ("stdio", "sse", or "http")
        command: Command to run (stdio only)
        args: Command arguments (stdio only)
        env: Environment variables (stdio only)
        url: Server URL (sse and http only)
        headers: HTTP headers (sse and http only)

    Returns:
        MCPConnection instance

    Raises:
        ValueError: If the transport is not supported, or the argument the
            chosen transport requires (command or url) is missing or empty.
    """
    transport = transport.lower()

    # Every branch either returns a connection or raises, so the checks are
    # written as independent guards rather than one if/elif ladder.
    if transport == "stdio":
        if command:
            return MCPConnectionStdio(command=command, args=args, env=env)
        raise ValueError("Command is required for stdio transport")

    if transport == "sse":
        if url:
            return MCPConnectionSSE(url=url, headers=headers)
        raise ValueError("URL is required for sse transport")

    if transport in ("http", "streamable_http", "streamable-http"):
        if url:
            return MCPConnectionHTTP(url=url, headers=headers)
        raise ValueError("URL is required for http transport")

    raise ValueError(f"Unsupported transport type: {transport}. Use 'stdio', 'sse', or 'http'")