import asyncio
import re
from datetime import datetime
from typing import Any, Awaitable, Callable

from agent.tools.mcp.mcp_proxy import McpProxy
from pathlib import Path
from markdown import markdown
from bs4 import BeautifulSoup
from docx import Document
from weasyprint import HTML
import tempfile

import asyncio
import inspect
from typing import (
    Any,
    Awaitable,
    Callable,
    get_type_hints,
    Union,
    get_origin,
    get_args,
    Literal,
    Annotated,
)

from agent.tools.mcp.mcp_proxy import McpProxy
from crawl4ai import AsyncWebCrawler
from googlesearch import search

class ToolManager:
    def __init__(self, config: dict[str, dict[str, Any]]):
        self.id=99999999
        self.proxy = McpProxy(config)
        self.local_tools: dict[str, dict[str, Any]] = {}
        self.local_tool_handlers: dict[str, Callable[..., Awaitable[Any]]] = {}

    async def initialize(self):
        await self.proxy.initialize()

    def _python_type_to_json_schema(self, typ: type) -> dict[str, Any]:
        origin = get_origin(typ)
        args = get_args(typ)

        if origin is Union:
            non_none_args = [arg for arg in args if arg is not type(None)]
            if len(non_none_args) == 1:
                return self._python_type_to_json_schema(non_none_args[0])
            else:
                return {"anyOf": [self._python_type_to_json_schema(arg) for arg in non_none_args]}

        elif typ is str:
            return {"type": "string"}
        elif typ in [int, float]:
            return {"type": "number"}
        elif typ is bool:
            return {"type": "boolean"}
        elif origin in [list, tuple]:
            return {
                "type": "array",
                "items": self._python_type_to_json_schema(args[0]) if args else {"type": "string"}
            }
        elif origin is dict:
            if len(args) == 2 and args[0] is str:
                return {
                    "type": "object",
                    "additionalProperties": self._python_type_to_json_schema(args[1])
                }
            else:
                return {"type": "object"}
        elif origin is Literal:
            return {"enum": list(args)}
        elif origin is Annotated:
            return self._python_type_to_json_schema(args[0])

        return {"type": "string"}

    def register_tool(
        self,
        name: str,
        description: str,
        input_schema: dict[str, Any],
        handler: Callable[..., Awaitable[Any]]
    ):
        self.local_tools[name] = {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": input_schema,
            }
        }
        self.local_tool_handlers[name] = handler

    def tool(self, func: Callable[..., Awaitable[Any]]):
        sig = inspect.signature(func)
        type_hints = get_type_hints(func)
        properties = {}
        required = []

        for name, param in sig.parameters.items():
            hint = type_hints.get(name, str)
            schema = self._python_type_to_json_schema(hint)
            properties[name] = {
                "type": schema.get("type", "string"),
                **schema,
                "description": f"{name}",
            }
            if param.default is inspect.Parameter.empty:
                required.append(name)

        input_schema = {
            "type": "object",
            "properties": properties,
            "required": required,
        }

        self.register_tool(
            name=func.__name__,
            description=func.__doc__ or "无描述",
            input_schema=input_schema,
            handler=func
        )
        return func

    async def list_tools(self) -> list[dict[str, Any]]:
        mcp_tools = await self.proxy.list_tools()
        return mcp_tools + list(self.local_tools.values())

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        if tool_name in self.local_tool_handlers:
            return await self.local_tool_handlers[tool_name](**arguments)
        return await self.proxy.call_tool(tool_name, arguments)

    async def refresh_tools(self):
        await self.proxy.refresh_tools()

    async def cleanup(self):
        await self.proxy.cleanup()





from datetime import datetime

def register_medical_tools(tool_manager):

    @tool_manager.tool
    async def get_patient_id_by_name(patient_name: str) -> str:
        """通过姓名查询患者唯一ID"""
        name_to_id = {
            "张三": "P001",
            "李四": "P002"
        }
        return name_to_id.get(patient_name, f"未找到名为 {patient_name} 的患者ID")

    @tool_manager.tool
    async def get_patient_profile(patient_id: str) -> str:
        """根据患者ID返回基本信息"""
        patient_profiles = {
            "P001": {"姓名": "张三", "年龄": 35, "性别": "男", "病史": "高血压"},
            "P002": {"姓名": "李四", "年龄": 28, "性别": "女", "病史": "无"},
        }
        profile = patient_profiles.get(patient_id)
        if not profile:
            return f"未找到患者ID {patient_id} 的资料"
        return f"患者 {profile['姓名']}，年龄 {profile['年龄']} 岁，性别 {profile['性别']}，病史：{profile['病史']}"

    @tool_manager.tool
    async def get_visit_history(patient_id: str) -> str:
        """查询患者的就诊记录"""
        visits = {
            "P001": [
                {"日期": "2025-04-01", "科室": "内科", "摘要": "血压偏高"},
                {"日期": "2025-05-12", "科室": "心内科", "摘要": "建议做动态心电图"},
            ],
            "P002": [
                {"日期": "2025-03-20", "科室": "妇科", "摘要": "常规体检"},
            ]
        }
        history = visits.get(patient_id)
        if not history:
            return f"未找到患者ID {patient_id} 的就诊记录"
        return "\n".join([f"{v['日期']} {v['科室']}：{v['摘要']}" for v in history])

    @tool_manager.tool
    async def get_recent_lab_results(patient_id: str) -> str:
        """根据患者ID返回最近一次化验结果"""
        lab_results = {
            "P001": {"血糖": "6.2 mmol/L", "胆固醇": "5.8 mmol/L"},
            "P002": {"血糖": "4.9 mmol/L", "胆固醇": "4.3 mmol/L"},
        }
        results = lab_results.get(patient_id)
        if not results:
            return f"未找到患者ID {patient_id} 的化验结果"
        return f"血糖：{results['血糖']}，胆固醇：{results['胆固醇']}"

    @tool_manager.tool
    async def get_doctor_notes(patient_id: str) -> str:
        """返回医生最近的诊疗记录"""
        notes = {
            "P001": "建议控制饮食，定期测血压。",
            "P002": "无异常，继续观察。",
        }
        return notes.get(patient_id, f"未找到患者ID {patient_id} 的医生记录")

    @tool_manager.tool
    async def get_medication_list(patient_id: str) -> str:
        """返回患者当前用药列表"""
        medications = {
            "P001": ["降压药A", "降压药B"],
            "P002": ["维生素C"],
        }
        meds = medications.get(patient_id)
        if not meds:
            return f"{patient_id} 当前无记录用药"
        return f"当前用药：{', '.join(meds)}"

    @tool_manager.tool
    async def schedule_appointment(patient_id: str, department: str, date: str) -> str:
        """为患者安排指定科室的预约"""
        try:
            datetime.strptime(date, "%Y-%m-%d")
        except ValueError:
            return "日期格式错误，请使用 YYYY-MM-DD 格式"

        return f"已为患者 {patient_id} 安排在 {date} 到 {department} 科室就诊"


def register_tools(tool_manager):
    @tool_manager.tool
    async def execute_sql(sql_list: list[str]) -> str:
        """
        并发执行传入的多条 SQL 语句，每条语句使用独立连接，避免并发冲突。
        """
        import aiomysql
        import asyncio

        if not sql_list or not isinstance(sql_list, list):
            return "[无有效 SQL 语句列表]"

        async def run_single_sql(sql_stmt: str):
            try:
                # 每个 SQL 独立连接
                conn = await aiomysql.connect(
                    host='10.10.0.2',
                    port=6606,
                    user='root',
                    password='yiPmJHHlRa',
                    db='crm',
                    autocommit=True
                )
                async with conn.cursor() as cur:
                    await cur.execute(sql_stmt)
                    result = await cur.fetchall()
                    columns = [desc[0] for desc in cur.description] if cur.description else []

                    await conn.ensure_closed()

                    if not result:
                        return f"[SQL 执行成功] 无结果: {sql_stmt}"

                    formatted = "\n".join(
                        [", ".join(columns)] +
                        [", ".join(str(item) for item in row) for row in result]
                    )
                    return f"[SQL 执行成功] {sql_stmt}\n{formatted}"

            except Exception as e:
                return f"[SQL 执行错误] {sql_stmt}\n错误: {e}"

        # 并发执行
        results = await asyncio.gather(
            *(run_single_sql(stmt.strip()) for stmt in sql_list if stmt.strip())
        )

        return "\n\n".join(results)
    # @tool_manager.tool
    async def show_rendered_element(code: str) -> str:
        """
        用于向用户展示一个html元素，你可以使用tailwind
        例如：
        <div class="p-6 bg-white dark:bg-gray-900 text-gray-900 dark:text-gray-100 rounded-lg shadow-md">
          <h1 class="text-3xl font-bold mb-4">欢迎使用 Tailwind CSS</h1>
        </div>
        """
        # 可选：这里进行代码安全检查或过滤
        return '用户已经成功看到了渲染过后的元素'

    @tool_manager.tool
    async def show_rendered_chart(option: str ) :
        """
            功能说明：
            接收用户传入的 ECharts 配置项（合法的 YAML 格式字符串），
            解析为 dict 后用于前端渲染任意类型的 ECharts 图表。

            注意：
            1. 不同图表类型的配置结构差异较大，必须严格按照 ECharts 官方文档对应图表类型的配置规范来填写。
            2. 参数中必须包含至少一个 series，且 series 中的每个元素 type 字段需对应 ECharts 支持的图表类型。
            3. 根据 series 中 type 的不同，必须校验对应配置项的必填字段和数据格式：
               - **bar/line/scatter** 等类目轴图：需配置有效的 xAxis、yAxis，series.data 为数组。
               - **pie/funnel**：无坐标轴，series.data 需为数组对象，每个对象需有 name 和 value。
               - **radar**：需配置 radar 指标，series.data 为多维数组。
               - **gauge**：需包含指针和刻度配置，series.data 通常为单个数值。
               - 其他图表请参考 [ECharts 官方文档](https://echarts.apache.org/zh/option.html)。

            4. 若配置不符合规范，返回详细的错误提示，指出具体缺失或格式错误的字段。
            5. 解析 YAML 时请确保格式正确，否则提示格式错误。
            6. 所有颜色值（如 color 字段）必须是字符串形式，必须用引号包裹，例如 "#5470c6"、"rgba(255, 0, 0, 0.5)" 等，避免被 YAML 解析器识别为注释。
            参数：
            - option: str，ECharts 配置的 YAML 格式字符串。

            返回：
            - 成功时返回提示“图表已展示”。
            - 出错时返回错误信息字符串。

            示例（柱状图和折线图混合）：
            title:
            text: 示例多图表
            tooltip: {}
            xAxis:
            type: category
            data: ["A", "B", "C"]
            yAxis:
            type: value
            series:

            type: bar
            data: [120, 200, 150]

            type: line
            data: [100, 180, 90]
        """

        import json
        import yaml

        try:
            if isinstance(option, str):
                option = yaml.safe_load(option)
        except Exception as e:
            return {
                "success":False,
                "error":f"参数 option YAML 解析失败: {str(e)}"
            }

        if not isinstance(option, dict):
            return {
                "success":False,
                "error":"参数格式错误：option 必须是 dict 或合法的 YAML 字符串"
            }

        # 这里你可以根据需要继续校验 option 是否包含 series 等

        return {
            "success": True,
            "message": "Echart已经成功向用户展示",
            "option": json.dumps(option, ensure_ascii=False)
        }

    @tool_manager.tool
    async def search_and_extract(query: str, top_k: int = 5) -> list[dict]:
        """
        搜索关键词并提取清洗后的网页正文，返回格式：
        [{"url": "xxx", "text": "正文内容"}, ...]

        参数：
        - query: 搜索关键词
        - top_k: 返回前 N 个搜索结果
        """

        def clean_markdown(md_text: str) -> str:
            """清理 Markdown 内容，仅保留有效中文正文"""
            md_text = re.sub(r'!\[.*?\]\(.*?\)', '', md_text)
            md_text = re.sub(r'\[.*?\]', '', md_text)
            md_text = re.sub(r'\(.*?\)', '', md_text)
            md_text = re.sub(r'[ \t]+', '', md_text)
            md_text = re.sub(r'\n\s*\n+', '', md_text)
            md_text = re.sub(r'\n', '', md_text)
            md_text = re.sub(r'[*#]', '', md_text)
            md_text = re.sub(r'^\s*[\*\-+]\s+', '', md_text, flags=re.MULTILINE)
            md_text = md_text.strip()

            chinese_chars = re.findall(r'[\u4e00-\u9fff]', md_text)
            ratio = len(chinese_chars) / len(md_text) if md_text else 0
            return md_text if ratio >= 0.3 else ""

        async def fetch_and_clean(url: str) -> dict:
            try:
                async with AsyncWebCrawler() as crawler:
                    result = await crawler.arun(url=url, strategy="auto")
                    markdown = result.markdown if result else ""
                    clean_text = clean_markdown(markdown)
                    return {"url": url, "text": clean_text or "[正文内容为空]"}
            except Exception as e:
                return {"url": url, "text": f"[抓取失败] {e}"}

        try:
            urls = list(search(query, num_results=top_k))
        except Exception as e:
            # 无 URL 可抓取，只能返回搜索失败信息
            return [{"url": "", "text": f"[搜索失败] {e}"}]

        tasks = [fetch_and_clean(url) for url in urls]
        results = await asyncio.gather(*tasks)
        return results

    @tool_manager.tool
    async def save_markdown_file(content: str, filename: str = "output", format: str = "pdf") -> dict:
        """
        📄 将 Markdown 文本保存为 PDF 或 DOCX 文件。

        🧠【用途说明】：
        - 当任务涉及 **论文撰写**、**总结报告**、**生成文档内容** 等场景时，应调用此工具将 markdown 内容转换为文档格式进行保存；
        - 可用于将系统生成的内容导出为 PDF 或 Word 文件，以供后续查看或下载。

        ⚙️【参数说明】：
        - content: markdown 格式的文本内容；
        - filename: 保存的文件名（不带后缀）；
        - format: 保存格式，支持 'pdf' 或 'docx'。

        ✅【返回结构】：
        {
            "path": "完整文件路径",
            "message": "成功或失败信息"
        }

        🗂️【存储位置】：固定保存至 D:/tmp/markdown_outputs 目录下。
        """
        try:
            # 修改后的存储目录
            output_dir = Path("D:/tmp/markdown_outputs")
            output_dir.mkdir(parents=True, exist_ok=True)

            file_path = output_dir / f"{filename}.{format}"

            # 将 Markdown 转为 HTML
            html_content = markdown(content)

            if format == "pdf":
                # 使用 weasyprint 转为 PDF
                HTML(string=html_content).write_pdf(str(file_path))
            elif format == "docx":
                # 使用 python-docx 构建简单文档
                soup = BeautifulSoup(html_content, "html.parser")
                doc = Document()
                for elem in soup.find_all(['p', 'h1', 'h2', 'h3', 'li']):
                    text = elem.get_text(strip=True)
                    if text:
                        doc.add_paragraph(text)
                doc.save(str(file_path))
            else:
                return {"path": "", "message": f"[失败] 不支持的格式: {format}"}

            return {"path": str(file_path), "message": "[成功] 文件已保存"}

        except Exception as e:
            return {"path": "", "message": f"[失败] {e}"}


CONFIG = {
    # "FastGPT": {
    #     "url": "https://cloud.fastgpt.cn/api/mcp/app/oj7BgLJmYh45J17YxC7EEFAj/mcp"
    # },
    "filesystem": {
        "args": [
            "-y",
            "@modelcontextprotocol/server-filesystem",
            "/"
        ],
        "command": "npx"
    }
}

tool_manager = ToolManager(CONFIG)

async def main():
    register_medical_tools(tool_manager)
    # await tool_manager.initialize()

    tools = await tool_manager.list_tools()
    print(tools)




if __name__ == '__main__':
    asyncio.run(main())
