# mcp_client/http_client.py

import asyncio
import logging
from contextlib import AsyncExitStack
from typing import Any

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from .base import BaseMcpClient


import asyncio
from typing import Any

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client

from .base import BaseMcpClient


class HttpMcpClient(BaseMcpClient):
    """MCP client that talks to a server over the streamable-HTTP transport.

    The server URL is read from ``config["url"]``. ``initialize()`` opens a
    long-lived session (used by ``list_tools``); ``call_tool()`` opens its own
    short-lived connection for every call.
    """

    def __init__(self, name: str, config: dict[str, Any]) -> None:
        super().__init__(name, config)
        # Serialises concurrent cleanup() calls so the exit stack is closed once.
        self._cleanup_lock = asyncio.Lock()
        # Owns the transport and session contexts entered by initialize();
        # None until initialize() succeeds, and again after cleanup().
        self._exit_stack: AsyncExitStack | None = None
        self.session: ClientSession | None = None

    async def initialize(self) -> None:
        """Open the HTTP transport and an initialized MCP session.

        Does nothing if a session is already open. If any step fails, every
        context entered so far is closed and the client stays uninitialized,
        so a later call can retry cleanly.

        Raises:
            ValueError: if ``config`` has no ``url``.
        """
        if self.session is not None:
            return  # Already initialized; skip.

        url = self.config.get("url")
        if not url:
            raise ValueError("Missing 'url' in config.")

        # Enter the contexts on a temporary stack: if anything below raises,
        # leaving the `async with` unwinds whatever was entered. On success,
        # pop_all() transfers the contexts to a new stack that cleanup() closes.
        async with AsyncExitStack() as stack:
            read_stream, write_stream, _ = await stack.enter_async_context(
                streamablehttp_client(url)
            )
            session = await stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            await session.initialize()
            self._exit_stack = stack.pop_all()
        # Publish the session only once it is fully initialized.
        self.session = session

    async def list_tools(self) -> list[Any]:
        """Return the tools listed by the server in a single ``list_tools`` request.

        Raises:
            RuntimeError: if ``initialize()`` has not completed successfully.
        """
        if self.session is None:
            raise RuntimeError(f"Client {self.name} not initialized")

        tools_response = await self.session.list_tools()
        # The response model exposes the tool list as its `tools` field.
        # NOTE(review): only one request is made; pagination is not handled.
        return list(tools_response.tools)

    async def call_tool(
            self,
            tool_name: str,
            arguments: dict[str, Any],
            retries: int = 2,
            delay: float = 1.0,
    ) -> Any:
        """Call ``tool_name`` with ``arguments`` and return its first text content.

        Each attempt opens a fresh HTTP connection and MCP session; the
        long-lived ``self.session`` is not used. ``retries`` is the total
        number of attempts (values below 1 are treated as 1), with ``delay``
        seconds of sleep between attempts.

        Raises:
            ValueError: if ``config`` has no ``url``.
            RuntimeError: if every attempt fails, or the tool result has no
                text in its first content item.
        """
        url = self.config.get("url")
        if not url:
            raise ValueError("Missing 'url' in config.")

        # Always make at least one attempt.
        attempts = max(1, retries)
        for attempt in range(1, attempts + 1):
            try:
                async with streamablehttp_client(url) as (read_stream, write_stream, _):
                    async with ClientSession(read_stream, write_stream) as session:
                        await session.initialize()
                        result = await session.call_tool(tool_name, arguments)
                break
            except Exception as e:
                if attempt < attempts:
                    await asyncio.sleep(delay)
                else:
                    raise RuntimeError(f"Tool call failed after {attempts} attempts: {e}") from e

        # Extract outside the retry loop: an unexpected result shape is not a
        # transport failure, and retrying would execute the tool again.
        content = result.content
        text = getattr(content[0], "text", None) if content else None
        if text is None:
            raise RuntimeError(f"Tool {tool_name} returned no text content")
        return text

    async def cleanup(self) -> None:
        """Close the session and transport opened by ``initialize()``.

        Safe to call repeatedly and before ``initialize()``. Client state is
        reset before closing, so it is cleared even if closing raises.

        NOTE(review): the transport context may need to be exited from the
        same task that entered it — confirm callers respect this.
        """
        async with self._cleanup_lock:
            exit_stack, self._exit_stack = self._exit_stack, None
            self.session = None
            if exit_stack is not None:
                await exit_stack.aclose()
