import asyncio
import importlib
import json
import traceback
import uuid
from typing import Dict, Set

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from agent.session import session

# Router that collects this module's WebSocket endpoint(s); presumably mounted
# on the FastAPI app elsewhere (TODO confirm).
router: APIRouter = APIRouter()


# Live SessionManager instances, keyed by the session_id from the URL path.
# Nothing in this module removes entries, so a manager outlives the WebSocket
# connections that attach to it.
session_managers: Dict[str, session.SessionManager] = {}

# One lock per session_id so that concurrent first connections for the same
# session create and initialize exactly one manager.
_session_locks: Dict[str, asyncio.Lock] = {}

# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so an otherwise unreferenced task may be garbage-collected
# before it finishes.
_background_tasks: Set[asyncio.Task] = set()


async def _get_or_create_manager(session_id: str) -> session.SessionManager:
    """Return the manager for *session_id*, creating and initializing it once.

    The manager is published in ``session_managers`` only after its tool manager
    has been initialized. Other connections therefore never observe a
    half-initialized manager, and a failing ``initialize()`` leaves no entry
    behind, so the next connection retries.
    """
    manager = session_managers.get(session_id)
    if manager is not None:
        return manager

    lock = _session_locks.setdefault(session_id, asyncio.Lock())
    async with lock:
        # Re-check: another connection may have created it while we waited.
        manager = session_managers.get(session_id)
        if manager is None:
            # Reload agent.session, presumably so new sessions pick up code
            # edits without a server restart.
            # NOTE(review): managers created earlier keep the old class; confirm
            # this hot-reload is intended outside development.
            importlib.reload(session)
            manager = session.SessionManager(session_id)
            await manager.main_agent.tool_manager.initialize()
            session_managers[session_id] = manager
    return manager


def _on_background_done(task: asyncio.Task) -> None:
    """Release a finished background task and report any exception it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        # Without this, the failure would only surface as "Task exception was
        # never retrieved" whenever the task happens to be collected.
        print("[WebSocket] 后台任务异常:")
        traceback.print_exception(type(exc), exc, exc.__traceback__)


def _spawn_background(coro) -> asyncio.Task:
    """Schedule *coro* as a task and keep a strong reference until it ends."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_done)
    return task


@router.websocket("/ws/{session_id}")
async def websocket_handler(websocket: WebSocket, session_id: str):
    """Serve one WebSocket client attached to agent session *session_id*.

    After connecting, the client is sent the session history and current state
    via ``manager.send_history_and_current``. Each incoming text frame must be a
    JSON object. ``{"action": "start", "content": ...}`` starts
    ``manager.start`` as a background task. Every other message is answered with
    a JSON ``{"error": ...}`` reply: malformed or non-object JSON, any message
    while ``manager.state == "running"``, or an unknown action.

    The socket is always unregistered from the manager when the handler exits,
    whether the client disconnected or an error occurred.
    """
    manager = await _get_or_create_manager(session_id)

    await manager.connect(websocket)
    try:
        await manager.send_history_and_current(websocket)

        while True:
            text = await websocket.receive_text()
            try:
                data = json.loads(text)
            except (ValueError, RecursionError):
                # ValueError covers JSONDecodeError; RecursionError is raised
                # for pathologically nested input.
                data = None
            # The protocol expects a JSON object; arrays, scalars and null would
            # otherwise crash on data.get() below.
            if not isinstance(data, dict):
                await websocket.send_text(json.dumps({"error": "消息格式错误，必须是JSON"}))
                continue

            if manager.state == "running":
                await websocket.send_text(json.dumps({"error": "任务正在执行中，请稍后再试"}))
                continue

            if data.get("action") == "start":
                content = data.get("content", "默认内容")
                manager.main_agent.cur_stream_id = uuid.uuid4().hex[:8]
                # Not tied to this socket: the run continues if this client leaves.
                _spawn_background(manager.start(content=content))
            else:
                await websocket.send_text(json.dumps({"error": "不支持的操作"}))

    except WebSocketDisconnect:
        print(f"[WebSocket] 断开: {session_id}")
    finally:
        # Unregister on every exit path so a failed socket is never left behind.
        manager.disconnect(websocket)