# mcp_client/base.py
from abc import ABC, abstractmethod
from typing import Any

from abc import ABC, abstractmethod
from typing import Any

from abc import ABC, abstractmethod
from typing import Any


class BaseMcpClient(ABC):
    """Abstract base class for MCP clients.

    Subclasses implement the transport-specific lifecycle (``initialize`` /
    ``cleanup``) and the raw protocol calls (``list_tools`` / ``call_tool``).
    This base class adds ``async with`` support and a cached list of
    :class:`Tool` objects via :meth:`tools`.
    """

    def __init__(self, name: str, config: dict[str, Any]) -> None:
        """Store the client's identity and configuration.

        Args:
            name: Identifier for this client.
            config: Connection configuration; its expected keys are defined
                by the concrete subclass.
        """
        self.name = name
        self.config = config
        # Cached result of ``tools()``; ``None`` means "not fetched yet".
        self._tools_cache: "list[Tool] | None" = None

    async def __aenter__(self) -> "BaseMcpClient":
        """Run :meth:`initialize` on entering ``async with`` and return self."""
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Run :meth:`cleanup` on leaving ``async with``.

        Returns ``None``, so exceptions raised inside the block propagate.
        """
        await self.cleanup()

    @abstractmethod
    async def initialize(self) -> None:
        """Open whatever connection/session the concrete transport needs."""

    @abstractmethod
    async def list_tools(self) -> list[Any]:
        """Return the server's raw tool listing.

        Each element must expose ``name``, ``description`` and ``inputSchema``
        attributes, because :meth:`tools` reads exactly those.
        """

    @abstractmethod
    async def call_tool(
            self,
            tool_name: str,
            arguments: dict[str, Any],
            retries: int = 2,
            delay: float = 1.0,
    ) -> Any:
        """Invoke ``tool_name`` with ``arguments`` and return the result.

        ``retries`` / ``delay`` are interpreted by the subclass — presumably
        the retry count and the pause in seconds between attempts (confirm
        against the concrete implementations).
        """

    @abstractmethod
    async def cleanup(self) -> None:
        """Release the resources acquired by :meth:`initialize`."""

    async def tools(self) -> "list[Tool]":
        """Return the server's tools as :class:`Tool` objects, cached after the first call.

        The first call fetches the raw listing via :meth:`list_tools` and
        wraps each entry in a ``Tool``. Later calls return the same list
        object until :meth:`clear_tool_cache` is called. An empty listing
        is cached as well.
        """
        if self._tools_cache is None:
            raw_tools = await self.list_tools()
            # Assigned only after every entry converted, so a failure leaves
            # the cache empty and the next call retries.
            self._tools_cache = [
                Tool(tool.name, tool.description, tool.inputSchema)
                for tool in raw_tools
            ]
        return self._tools_cache

    def clear_tool_cache(self) -> None:
        """Drop the cached tool list so the next :meth:`tools` call refetches it."""
        self._tools_cache = None


from .stdio_client import StdioMcpClient
from .http_client import HttpMcpClient
from .base import BaseMcpClient


def create_mcp_client(name: str, config: dict[str, Any]) -> "BaseMcpClient":
    """Build the MCP client whose transport matches ``config``.

    Args:
        name: Identifier passed through to the client constructor.
        config: Server configuration. A ``"url"`` key selects the HTTP
            transport and is checked first; otherwise the presence of both
            ``"command"`` and ``"args"`` selects the stdio transport.

    Returns:
        An ``HttpMcpClient`` or ``StdioMcpClient`` instance.

    Raises:
        ValueError: if ``config`` matches neither transport. The message
            lists only the config's keys, never its values, since values
            (e.g. environment variables or headers) may hold credentials.
    """
    if "url" in config:
        return HttpMcpClient(name, config)
    if "command" in config and "args" in config:
        return StdioMcpClient(name, config)
    raise ValueError(
        f"Invalid config for MCP client '{name}': expected a 'url' key or both "
        f"'command' and 'args' keys, got keys {list(config)}"
    )


class Tool:
    """Represents a tool: its name, description and JSON-Schema input, plus LLM-facing formatters."""

    def __init__(
            self, name: str, description: str, input_schema: dict[str, Any]
    ) -> None:
        self.name: str = name
        self.description: str = description
        # JSON Schema for the tool's arguments; ``format_for_llm`` reads its
        # "properties" and "required" keys, ``to_dict`` passes it through as-is.
        self.input_schema: dict[str, Any] = input_schema

    def format_for_llm(self) -> str:
        """Render the tool as a plain-text block for an LLM prompt.

        Each entry of the schema's ``properties`` becomes a line
        ``- <name>: <description>`` (``No description`` when absent or when the
        sub-schema is not a dict, e.g. a boolean schema), suffixed with
        `` (required)`` when the name is listed in the schema's ``required``.
        A ``properties`` or ``required`` key set to ``None`` is treated as empty.

        Returns:
            The formatted text, starting and ending with a newline.
        """
        # Hoisted out of the loop; ``or`` also covers an explicit ``None``.
        required = self.input_schema.get("required") or []
        properties = self.input_schema.get("properties") or {}

        args_desc = []
        for param_name, param_info in properties.items():
            # Boolean sub-schemas (``true``/``false``) are valid JSON Schema
            # but have no ``.get`` and carry no description.
            if isinstance(param_info, dict):
                description = param_info.get("description", "No description")
            else:
                description = "No description"
            arg_desc = f"- {param_name}: {description}"
            if param_name in required:
                arg_desc += " (required)"
            args_desc.append(arg_desc)

        return f"""
Tool: {self.name}
Description: {self.description}
Arguments:
{chr(10).join(args_desc)}
"""

    def to_dict(self) -> dict[str, Any]:
        """Serialize the tool to the OpenAI function-calling schema.

        The input schema is passed through unchanged as ``parameters``.
        """
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.input_schema
            }
        }



class MCPClientProxy:
    """Expose a flat list of tool descriptors and route calls to the owning tool set.

    Each entry of ``tools`` is a dict read via the keys ``name``,
    ``description``, ``inputSchema`` and ``toolSetId``. Calls are dispatched
    to the client registered under ``toolSetId`` in ``tool_set`` from
    ``HotCode.core.globals``.
    """

    def __init__(self, tools):
        # Tool descriptor dicts (see class docstring for the keys read).
        self.tools = tools

    def list_tools(self):
        """Return the tools in OpenAI function-calling schema.

        Returns:
            One ``{"type": "function", "function": {...}}`` dict per entry of
            ``self.tools``, in the same order; ``inputSchema`` becomes
            ``parameters``.
        """
        return [
            {
                "type": 'function',
                "function": {
                    "name": tool['name'],
                    "description": tool['description'],
                    "parameters": tool['inputSchema'],
                },
            }
            for tool in self.tools
        ]

    async def call_tool(self, tool_name, arguments):
        """Invoke ``tool_name`` through the client of the tool set that owns it.

        Args:
            tool_name: Tool to call; the first entry in ``self.tools`` with this
                name determines the tool set.
            arguments: Forwarded unchanged to the client's ``call_tool``.

        Returns:
            The client's ``call_tool`` result, or ``None`` when no entry in
            ``self.tools`` has this name.

        Raises:
            Exception: if the owning tool set is missing from ``tool_set``.
        """
        tool = next((t for t in self.tools if t['name'] == tool_name), None)
        if tool is None:
            return None

        # Imported lazily, presumably to avoid an import cycle — confirm.
        from HotCode.core.globals import tool_set, tool_set_lock

        # NOTE(review): the lock is held for the whole call, including entering
        # and exiting the client context, which presumably serializes all tool
        # calls. It also prevents two coroutines from entering the same client
        # concurrently, so narrowing it would need per-client protection.
        async with tool_set_lock:
            tool_client = tool_set.get(tool['toolSetId'])
            if not tool_client:
                # "<tool>: its tool set <id> does not exist or is not published"
                raise Exception(f"{tool_name}  所属的工具集{tool['toolSetId']}不存在或未发布")
            async with tool_client as client:
                return await client.call_tool(tool_name, arguments)