main
1"""Lightweight connection handling for MCP servers."""
2
3from abc import ABC, abstractmethod
4from contextlib import AsyncExitStack
5from typing import Any
6
7from mcp import ClientSession, StdioServerParameters
8from mcp.client.sse import sse_client
9from mcp.client.stdio import stdio_client
10from mcp.client.streamable_http import streamablehttp_client
11
12
class MCPConnection(ABC):
    """Base class for MCP server connections.

    Subclasses supply the transport by implementing ``_create_context``. This
    class opens that transport, wraps its read/write streams in a
    ``ClientSession``, initializes the session, and tears everything down
    again on exit. Use it as an async context manager::

        async with connection:
            tools = await connection.list_tools()

    Attributes:
        session: The initialized ``ClientSession`` while the connection is
            open, otherwise ``None``.
    """

    def __init__(self):
        self.session = None
        # Owns the transport and session contexts while the connection is
        # open; ``None`` whenever the connection is closed.
        self._stack = None

    @abstractmethod
    def _create_context(self):
        """Return the async context manager that opens the transport.

        Entering it must yield a 2- or 3-item sequence whose first two items
        are the read and write streams handed to ``ClientSession``.
        """

    async def __aenter__(self):
        """Open the transport, start a ``ClientSession`` and initialize it.

        Returns:
            ``self``, with ``session`` set to the initialized session.

        Raises:
            RuntimeError: If the connection is already open.
            ValueError: If the transport context yields something other than
                a 2- or 3-item sequence.
        """
        if self._stack is not None:
            # Re-entering would overwrite (and leak) the live exit stack.
            raise RuntimeError("MCP connection is already open")

        async with AsyncExitStack() as stack:
            result = await stack.enter_async_context(self._create_context())
            if len(result) not in (2, 3):
                raise ValueError(f"Unexpected context result: {result}")
            # A 3-item result's third element is ignored; only the streams are used.
            read, write = result[0], result[1]

            session = await stack.enter_async_context(ClientSession(read, write))
            await session.initialize()

            # Everything succeeded: move ownership of the open contexts out of
            # the local stack so they stay open until __aexit__. If anything
            # above raised, the ``async with`` unwinds them with the real
            # exception info and ``self`` is left in its closed state.
            self._stack = stack.pop_all()

        self.session = session
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the session and transport, then reset the connection state.

        Returns:
            True if one of the wrapped contexts suppressed the in-flight
            exception, so the caller's ``async with`` honors that decision;
            otherwise False.
        """
        stack = self._stack
        if stack is None:
            return False
        try:
            return await stack.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            # Reset even if cleanup raises, so the object is never left
            # pointing at a half-closed stack or a dead session.
            self.session = None
            self._stack = None

    def _require_session(self):
        """Return the open session, or raise if the connection is not open."""
        if self.session is None:
            raise RuntimeError(
                "MCP connection is not open; enter it with 'async with' first"
            )
        return self.session

    async def list_tools(self) -> list[dict[str, Any]]:
        """Retrieve available tools from the MCP server.

        Returns:
            One dict per tool with ``name``, ``description`` and
            ``input_schema`` keys, copied from the server's tool listing.

        Raises:
            RuntimeError: If the connection is not open.
        """
        response = await self._require_session().list_tools()
        return [
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.inputSchema,
            }
            for tool in response.tools
        ]

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call a tool on the MCP server with provided arguments.

        Returns:
            The ``content`` attribute of the server's tool-call result.

        Raises:
            RuntimeError: If the connection is not open.
        """
        result = await self._require_session().call_tool(tool_name, arguments=arguments)
        return result.content
71
72
class MCPConnectionStdio(MCPConnection):
    """Connection to an MCP server launched as a subprocess over stdin/stdout.

    Args:
        command: Executable to launch.
        args: Command-line arguments; a falsy value (``None`` or ``[]``)
            is normalized to a fresh empty list.
        env: Environment mapping forwarded unchanged to
            ``StdioServerParameters`` (``None`` is passed through as-is).
    """

    def __init__(self, command: str, args: list[str] = None, env: dict[str, str] = None):
        super().__init__()
        self.command = command
        self.args = args if args else []
        self.env = env

    def _create_context(self):
        # Bundle the launch settings into the parameter object the stdio
        # transport expects, then hand it over.
        launch_params = StdioServerParameters(
            command=self.command,
            args=self.args,
            env=self.env,
        )
        return stdio_client(launch_params)
86
87
class MCPConnectionSSE(MCPConnection):
    """Connection to an MCP server over Server-Sent Events.

    Args:
        url: Server endpoint handed to ``sse_client``.
        headers: Extra HTTP headers; a falsy value (``None`` or ``{}``) is
            normalized to a fresh empty dict.
    """

    def __init__(self, url: str, headers: dict[str, str] = None):
        super().__init__()
        self.url = url
        self.headers = headers if headers else {}

    def _create_context(self):
        # A new transport context is built on every call.
        return sse_client(headers=self.headers, url=self.url)
98
99
class MCPConnectionHTTP(MCPConnection):
    """Connection to an MCP server over the Streamable HTTP transport.

    Args:
        url: Server endpoint handed to ``streamablehttp_client``.
        headers: Extra HTTP headers; a falsy value (``None`` or ``{}``) is
            normalized to a fresh empty dict.
    """

    def __init__(self, url: str, headers: dict[str, str] = None):
        super().__init__()
        if not headers:
            headers = {}
        self.url = url
        self.headers = headers

    def _create_context(self):
        endpoint, extra_headers = self.url, self.headers
        return streamablehttp_client(url=endpoint, headers=extra_headers)
110
111
def create_connection(
    transport: str,
    command: str = None,
    args: list[str] = None,
    env: dict[str, str] = None,
    url: str = None,
    headers: dict[str, str] = None,
) -> MCPConnection:
    """Factory function to create the appropriate MCP connection.

    The transport name is matched case-insensitively and ignoring surrounding
    whitespace, so values such as ``" HTTP "`` read from config files work.

    Args:
        transport: Connection type ("stdio", "sse", or "http"; "streamable_http"
            and "streamable-http" are accepted as aliases for "http")
        command: Command to run (stdio only)
        args: Command arguments (stdio only)
        env: Environment variables (stdio only)
        url: Server URL (sse and http only)
        headers: HTTP headers (sse and http only)

    Returns:
        MCPConnection instance; it is only constructed here, not opened.

    Raises:
        ValueError: If ``transport`` is not a string or names an unsupported
            transport, or if the argument that transport requires
            (``command`` for stdio, ``url`` for sse/http) is missing.
    """
    if not isinstance(transport, str):
        # Without this guard a missing config value (None) fails with an
        # opaque AttributeError from ``.lower()``.
        raise ValueError(f"Transport must be a string, got {type(transport).__name__}")
    transport = transport.strip().lower()

    if transport == "stdio":
        if not command:
            raise ValueError("Command is required for stdio transport")
        return MCPConnectionStdio(command=command, args=args, env=env)

    if transport == "sse":
        if not url:
            raise ValueError("URL is required for sse transport")
        return MCPConnectionSSE(url=url, headers=headers)

    if transport in {"http", "streamable_http", "streamable-http"}:
        if not url:
            raise ValueError("URL is required for http transport")
        return MCPConnectionHTTP(url=url, headers=headers)

    raise ValueError(f"Unsupported transport type: {transport}. Use 'stdio', 'sse', or 'http'")